diff --git a/doc/source/custom_directives.py b/doc/source/custom_directives.py index 0759f7d5fcde..d5686e8c5bf3 100644 --- a/doc/source/custom_directives.py +++ b/doc/source/custom_directives.py @@ -91,6 +91,8 @@ def update_context(app, pagename, templatename, context, doctree): "ax.service.ax_client", "ConfigSpace", "dask.distributed", + "datasets", + "datasets.iterable_dataset", "gym", "gym.spaces", "horovod", @@ -128,6 +130,18 @@ def update_context(app, pagename, templatename, context, doctree): "tensorflow", "tensorflow.contrib", "tensorflow.contrib.all_reduce", + "transformers", + "transformers.modeling_utils", + "transformers.models", + "transformers.models.auto", + "transformers.pipelines", + "transformers.pipelines.table_question_answering", + "transformers.trainer", + "transformers.training_args", + "transformers.trainer_callback", + "transformers.utils", + "transformers.utils.logging", + "transformers.utils.versions", "tree", "tensorflow.contrib.all_reduce.python", "tensorflow.contrib.layers", diff --git a/doc/source/ray-air/getting-started.rst b/doc/source/ray-air/getting-started.rst index 020e956ba558..034171c5a8ac 100644 --- a/doc/source/ray-air/getting-started.rst +++ b/doc/source/ray-air/getting-started.rst @@ -60,6 +60,10 @@ Trainer :members: :show-inheritance: +.. automodule:: ray.ml.train.integrations.huggingface + :members: + :show-inheritance: + .. automodule:: ray.ml.train.integrations.sklearn :members: :show-inheritance: @@ -112,6 +116,10 @@ Predictors :members: :show-inheritance: +.. automodule:: ray.ml.predictors.integrations.huggingface + :members: + :show-inheritance: + .. _air-serve-integration: Serving diff --git a/python/ray/data/dataset.py b/python/ray/data/dataset.py index f5fffa2d4778..7e4e7da41a65 100644 --- a/python/ray/data/dataset.py +++ b/python/ray/data/dataset.py @@ -25,6 +25,7 @@ import pyspark import torch import tensorflow as tf + import torch.utils.data from ray.data.dataset_pipeline import DatasetPipeline from ray.data.grouped_dataset import GroupedDataset @@ -310,7 +311,8 @@ def transform(block: Block) -> Iterable[Block]: ): raise ValueError( "The map batches UDF returned the value " - f"{applied}, which is not allowed. " + f"{applied} of type {type(applied)}, " + "which is not allowed. " "The return type must be either list, " "pandas.DataFrame, or pyarrow.Table" ) @@ -2124,6 +2126,7 @@ def to_torch( prefetch_blocks: int = 0, drop_last: bool = False, unsqueeze_label_tensor: bool = True, + unsqueeze_feature_tensors: bool = True, ) -> "torch.utils.data.IterableDataset": """Return a Torch IterableDataset over this dataset. @@ -2197,6 +2200,10 @@ def to_torch( be left as is, that is (N, ). In general, regression loss functions expect an unsqueezed tensor, while classification loss functions expect a squeezed one. Defaults to True. + unsqueeze_feature_tensors (bool): If set to True, the features tensors + will be unsqueezed (reshaped to (N, 1)) before being concatenated into + the final features tensor. Otherwise, they will be left as is, that is + (N, ). Defaults to True. Returns: A torch IterableDataset. 
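The docstring added above describes the new ``unsqueeze_feature_tensors`` flag alongside the existing ``unsqueeze_label_tensor`` one. Below is a minimal sketch of how the two flags are expected to affect the tensors yielded by ``Dataset.to_torch``; the toy columns ``x``/``y`` and ``batch_size=2`` are made up for illustration, and the commented shapes follow the docstring above rather than verified output.

.. code-block:: python

    import pandas as pd
    import ray

    ray.init()
    # Toy dataset, purely for illustration.
    ds = ray.data.from_pandas(
        pd.DataFrame({"x": [1.0, 2.0, 3.0, 4.0], "y": [0, 1, 0, 1]})
    )

    # Default behavior: the label and each feature column are reshaped to
    # (N, 1) before the feature columns are concatenated together.
    torch_ds = ds.to_torch(label_column="y", batch_size=2)

    # With the new flag, the per-column feature tensors (and the label, if
    # unsqueeze_label_tensor=False) are left as (N, ) instead.
    torch_ds_flat = ds.to_torch(
        label_column="y",
        batch_size=2,
        unsqueeze_label_tensor=False,
        unsqueeze_feature_tensors=False,
    )

    for features, label in torch_ds_flat:
        # Expected per the docstring: (2,) and (2,) for this single-feature case.
        print(features.shape, label.shape)
        break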
@@ -2248,10 +2255,13 @@ def make_generator(): drop_last=drop_last, ): if label_column: - label_vals = batch.pop(label_column).values - label_tensor = torch.as_tensor(label_vals, dtype=label_column_dtype) - if unsqueeze_label_tensor: - label_tensor = label_tensor.view(-1, 1) + label_tensor = convert_pandas_to_torch_tensor( + batch, + [label_column], + label_column_dtype, + unsqueeze=unsqueeze_label_tensor, + ) + batch.pop(label_column) else: label_tensor = None @@ -2263,6 +2273,7 @@ def make_generator(): feature_column_dtypes[key] if isinstance(feature_column_dtypes, dict) else feature_column_dtypes, + unsqueeze=unsqueeze_feature_tensors, ) for key in feature_columns } @@ -2271,6 +2282,7 @@ def make_generator(): batch, columns=feature_columns, column_dtypes=feature_column_dtypes, + unsqueeze=unsqueeze_feature_tensors, ) yield (features_tensor, label_tensor) diff --git a/python/ray/ml/BUILD b/python/ray/ml/BUILD index 84d6f72d53d0..ea0a3a15864c 100644 --- a/python/ray/ml/BUILD +++ b/python/ray/ml/BUILD @@ -21,6 +21,15 @@ py_test( args = ["--use-gpu", "--num-workers=2", "--epochs=1", "--dataset=fake"] ) +py_test ( + name = "huggingface_basic_language_modeling_example", + size = "medium", + srcs = ["examples/huggingface/huggingface_basic_language_modeling_example.py"], + args = ["--smoke-test", "--num-epochs 3"], + tags = ["team:ml", "exclusive"], + deps = [":ml_lib"] +) + py_test ( name = "lightgbm_example", size = "medium", @@ -182,6 +191,22 @@ py_test( deps = [":ml_lib"] ) +py_test( + name = "test_huggingface_predictor", + size = "small", + srcs = ["tests/test_huggingface_predictor.py"], + tags = ["team:ml", "exclusive"], + deps = [":ml_lib"] +) + +py_test( + name = "test_huggingface_trainer", + size = "medium", + srcs = ["tests/test_huggingface_trainer.py"], + tags = ["team:ml", "exclusive"], + deps = [":ml_lib"] +) + py_test( name = "test_lightgbm_predictor", size = "small", diff --git a/python/ray/ml/examples/huggingface/__init__.py b/python/ray/ml/examples/huggingface/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py b/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py new file mode 100644 index 000000000000..178b7b202a4e --- /dev/null +++ b/python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py @@ -0,0 +1,213 @@ +# Based on +# huggingface/notebooks/examples/language_modeling_from_scratch.ipynb + +import argparse +import tempfile +from datasets import load_dataset +from transformers import ( + AutoTokenizer, + AutoConfig, + AutoModelForCausalLM, + Trainer, + TrainingArguments, +) + +import pandas as pd +import torch + +import ray +import ray.data +from ray.ml.train.integrations.huggingface import HuggingFaceTrainer +from ray.ml.predictors.integrations.huggingface import HuggingFacePredictor +from ray.ml.batch_predictor import BatchPredictor + + +def main( + model_checkpoint="gpt2", + tokenizer_checkpoint="sgugger/gpt2-like-tokenizer", + dataset_name="wikitext-2-raw-v1", + dataset_path="wikitext", + num_epochs=5, + num_workers=2, + use_gpu=False, + smoke_test=False, +): + block_size = 128 + + # Uncomment the following if the maximum length the model was + # pretrained with can fit in your memory. 
+ # block_size = tokenizer.model_max_length + + # Run this as a remote function to avoid downloading on the driver + @ray.remote + def get_dataset(): + datasets = load_dataset(dataset_path, dataset_name) + tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint) + + def tokenize_function(examples): + return tokenizer(examples["text"]) + + tokenized_datasets = datasets.map( + tokenize_function, batched=True, num_proc=1, remove_columns=["text"] + ) + + def group_texts(examples): + # Concatenate all texts. + concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()} + total_length = len(concatenated_examples[list(examples.keys())[0]]) + # We drop the small remainder. We could add padding if the model supported + # it instead of this drop. You can customize this part to your needs. + total_length = (total_length // block_size) * block_size + # Split by chunks of max_len. + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } + result["labels"] = result["input_ids"].copy() + return result + + lm_datasets = tokenized_datasets.map( + group_texts, + batched=True, + batch_size=1000, + num_proc=1, + ) + ray_train = ray.data.from_arrow(lm_datasets["train"]._data.table) + ray_validation = ray.data.from_arrow(lm_datasets["validation"]._data.table) + return ray_train, ray_validation + + ray_train, ray_validation = ray.get(get_dataset.remote()) + + def train_function(train_dataset, eval_dataset=None, **config): + model_config = AutoConfig.from_pretrained(model_checkpoint) + model = AutoModelForCausalLM.from_config(model_config) + print("Initializing TrainingArguments...") + # The checkpoints will be moved to Ray Tune results + # directory automatically + training_dir = tempfile.mkdtemp() + training_args = TrainingArguments( + training_dir, + evaluation_strategy="epoch", + num_train_epochs=num_epochs, + learning_rate=2e-5, + weight_decay=0.01, + disable_tqdm=True, + save_strategy="epoch", + # Required to avoid an exception + no_cuda=not torch.cuda.is_available(), + ) + print("Initializing Trainer...") + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + ) + print("Trainer initialized! 
Starting training...") + return trainer + + if smoke_test: + ray_train = ray_train.limit(16) + ray_validation = ray_validation.limit(8) + + trainer = HuggingFaceTrainer( + trainer_init_per_worker=train_function, + scaling_config={"num_workers": num_workers, "use_gpu": use_gpu}, + datasets={"train": ray_train, "evaluation": ray_validation}, + ) + results = trainer.fit() + print(results.metrics) + + tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint) + prompt = ["My text: Complete me..."] + predictor = BatchPredictor.from_checkpoint( + results.checkpoint, + HuggingFacePredictor, + task="text-generation", + tokenizer=tokenizer, + ) + data = ray.data.from_pandas(pd.DataFrame(prompt, columns=["prompt"])) + prediction = predictor.predict(data, num_gpus_per_worker=int(use_gpu)) + prediction = prediction.to_pandas().iloc[0]["generated_text"] + + print(f"Generated text for prompt '{prompt}': '{prediction}'") + + +if __name__ == "__main__": + # Training settings + parser = argparse.ArgumentParser( + description="Language modelling from scratch with HuggingFaceTrainer Example", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--model-checkpoint", + type=str, + default="gpt2", + help="Model checkpoint name to download from HF hub", + ) + parser.add_argument( + "--tokenizer-checkpoint", + type=str, + default="sgugger/gpt2-like-tokenizer", + help="Tokenizer checkpoint name to download from HF hub", + ) + parser.add_argument( + "--dataset-name", + type=str, + default="wikitext-2-raw-v1", + help="Dataset name to download from HF hub", + ) + parser.add_argument( + "--dataset-path", + type=str, + default="wikitext", + help="Path on the head node to save the dataset to", + ) + parser.add_argument( + "--num-epochs", + type=int, + default=5, + help="number of epochs to train (default: 5)", + ) + parser.add_argument( + "--use-gpu", action="store_true", default=False, help="enables CUDA training" + ) + parser.add_argument( + "--num-workers", + type=int, + default=2, + help="Number of Ray workers to use for training.", + ) + parser.add_argument( + "--smoke-test", + action="store_true", + default=False, + help="Limit dataset size to finish quickly for testing", + ) + parser.add_argument( + "--address", + required=False, + type=str, + default=None, + help="Address of Ray cluster.", + ) + + args = parser.parse_args() + + # Requires at least torch 1.11 to pass + runtime_env = {"pip": ["torch==1.11.0"]} + if args.address: + ray.init(args.address, runtime_env=runtime_env) + else: + ray.init(runtime_env=runtime_env) + + main( + model_checkpoint=args.model_checkpoint, + tokenizer_checkpoint=args.tokenizer_checkpoint, + dataset_name=args.dataset_name, + dataset_path=args.dataset_path, + num_epochs=args.num_epochs, + num_workers=args.num_workers, + use_gpu=args.use_gpu, + smoke_test=args.smoke_test, + ) diff --git a/python/ray/ml/predictors/integrations/huggingface/__init__.py b/python/ray/ml/predictors/integrations/huggingface/__init__.py new file mode 100644 index 000000000000..617063387e1d --- /dev/null +++ b/python/ray/ml/predictors/integrations/huggingface/__init__.py @@ -0,0 +1,5 @@ +from ray.ml.predictors.integrations.huggingface.huggingface_predictor import ( + HuggingFacePredictor, +) + +__all__ = ["HuggingFacePredictor"] diff --git a/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py new file mode 100644 index 000000000000..38933cf22a5e --- /dev/null +++ 
b/python/ray/ml/predictors/integrations/huggingface/huggingface_predictor.py @@ -0,0 +1,170 @@ +import os +from typing import Optional, Type, Union, List + +import numpy as np +import pandas as pd +from ray.ml.constants import PREPROCESSOR_KEY + +from transformers.pipelines import Pipeline, pipeline as pipeline_factory +from transformers.pipelines.table_question_answering import ( + TableQuestionAnsweringPipeline, +) + +import ray.cloudpickle as cpickle +from ray.ml.predictor import DataBatchType, Predictor +from ray.ml.preprocessor import Preprocessor +from ray.ml.checkpoint import Checkpoint + + +class HuggingFacePredictor(Predictor): + """A predictor for HuggingFace Transformers PyTorch models. + + This predictor uses Transformers Pipelines for inference. + + Args: + pipeline: The Transformers pipeline to use for inference. + preprocessor: A preprocessor used to transform data batches prior + to prediction. + """ + + def __init__( + self, + pipeline: Optional[Pipeline] = None, + preprocessor: Optional[Preprocessor] = None, + ): + self.pipeline = pipeline + self.preprocessor = preprocessor + + @classmethod + def from_checkpoint( + cls, + checkpoint: Checkpoint, + *, + pipeline: Optional[Type[Pipeline]] = None, + **pipeline_kwargs, + ) -> "HuggingFacePredictor": + """Instantiate the predictor from a Checkpoint. + + The checkpoint is expected to be a result of ``HuggingFaceTrainer``. + + Args: + checkpoint: The checkpoint to load the model and + preprocessor from. It is expected to be from the result of a + ``HuggingFaceTrainer`` run. + pipeline: A ``transformers.pipelines.Pipeline`` class to use. + If not specified, will use the ``pipeline`` abstraction + wrapper. + **pipeline_kwargs: Any kwargs to pass to the pipeline + initialization. If ``pipeline`` is None, this must contain + the 'task' argument. Cannot contain 'model'. + """ + if not pipeline and "task" not in pipeline_kwargs: + raise ValueError( + "If `pipeline` is not specified, 'task' must be passed as a kwarg." + ) + pipeline = pipeline or pipeline_factory + with checkpoint.as_directory() as checkpoint_path: + preprocessor_path = os.path.join(checkpoint_path, PREPROCESSOR_KEY) + if os.path.exists(preprocessor_path): + with open(preprocessor_path, "rb") as f: + preprocessor = cpickle.load(f) + else: + preprocessor = None + pipeline = pipeline(model=checkpoint_path, **pipeline_kwargs) + return HuggingFacePredictor( + pipeline=pipeline, + preprocessor=preprocessor, + ) + + def _predict( + self, data: Union[list, pd.DataFrame], **pipeline_call_kwargs + ) -> pd.DataFrame: + ret = self.pipeline(data, **pipeline_call_kwargs) + # Remove unnecessary lists + try: + new_ret = [x[0] if isinstance(x, list) and len(x) == 1 else x for x in ret] + df = pd.DataFrame(new_ret) + except Exception: + # if we fail for any reason, just give up + df = pd.DataFrame(ret) + df.columns = [str(col) for col in df.columns] + return df + + def _convert_data_for_pipeline( + self, data: pd.DataFrame + ) -> Union[list, pd.DataFrame]: + """Convert the data into a format accepted by the pipeline. 
+ + In most cases, this format is a list of strings.""" + # Special case + if isinstance(self.pipeline, TableQuestionAnsweringPipeline): + return data + # Otherwise, a list of columns as lists + columns = [data[col].to_list() for col in data.columns] + # Flatten if it's only one column + if len(columns) == 1: + columns = columns[0] + return columns + + def predict( + self, + data: DataBatchType, + feature_columns: Optional[List[str]] = None, + **pipeline_call_kwargs, + ) -> DataBatchType: + """Run inference on data batch. + + The data is converted into a list (unless ``pipeline`` is a + ``TableQuestionAnsweringPipeline``) and passed to the ``pipeline`` + object. + + Args: + data: A batch of input data. Either a pandas DataFrame or numpy + array. + feature_columns: The names or indices of the columns in the + data to use as features to predict on. If None, use all + columns. + **pipeline_call_kwargs: additional kwargs to pass to the + ``pipeline`` object. + + Examples: + + .. code-block:: python + + import pandas as pd + from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + from transformers.pipelines import pipeline + from ray.ml.predictors.integrations.huggingface import HuggingFacePredictor + + model_checkpoint = "gpt2" + tokenizer_checkpoint = "sgugger/gpt2-like-tokenizer" + tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint) + + model_config = AutoConfig.from_pretrained(model_checkpoint) + model = AutoModelForCausalLM.from_config(model_config) + predictor = HuggingFacePredictor( + pipeline=pipeline( + task="text-generation", model=model, tokenizer=tokenizer + ) + ) + + prompts = pd.DataFrame( + ["Complete me", "And me", "Please complete"], columns=["sentences"] + ) + predictions = predictor.predict(prompts) + + + Returns: + DataBatchType: Prediction result. + """ + if self.preprocessor: + data = self.preprocessor.transform_batch(data) + + if isinstance(data, np.ndarray): + # If numpy array, then convert to pandas dataframe. + data = pd.DataFrame(data) + + data = data[feature_columns] if feature_columns else data + + data = self._convert_data_for_pipeline(data) + return self._predict(data, **pipeline_call_kwargs) diff --git a/python/ray/ml/predictors/integrations/torch/torch_predictor.py b/python/ray/ml/predictors/integrations/torch/torch_predictor.py index a9e8eedffa39..c03493cbca8e 100644 --- a/python/ray/ml/predictors/integrations/torch/torch_predictor.py +++ b/python/ray/ml/predictors/integrations/torch/torch_predictor.py @@ -24,7 +24,6 @@ def __init__( self, model: torch.nn.Module, preprocessor: Optional[Preprocessor] = None ): self.model = model - self.model.eval() self.preprocessor = preprocessor @classmethod @@ -56,6 +55,45 @@ def from_checkpoint( ) return TorchPredictor(model=model, preprocessor=preprocessor) + # parity with Dataset.to_torch + def _convert_to_tensor( + self, + data: pd.DataFrame, + feature_columns: Optional[ + Union[List[str], List[List[str]], List[int], List[List[int]]] + ] = None, + dtypes: Optional[torch.dtype] = None, + unsqueeze: bool = True, + ) -> torch.Tensor: + """Handle conversion of data to tensor. + + Same arguments as in ``convert_pandas_to_torch_tensor``.""" + # TODO(amog): Add `_convert_numpy_to_torch_tensor to use based on input type.
+ # Reduce conversion cost if input is in Numpy + if isinstance(feature_columns, dict): + features_tensor = { + key: convert_pandas_to_torch_tensor( + data, + feature_columns[key], + dtypes[key] if isinstance(dtypes, dict) else dtypes, + unsqueeze=unsqueeze, + ) + for key in feature_columns + } + else: + features_tensor = convert_pandas_to_torch_tensor( + data, + columns=feature_columns, + column_dtypes=dtypes, + unsqueeze=unsqueeze, + ) + return features_tensor + + def _predict(self, tensor: torch.Tensor) -> pd.DataFrame: + """Handle actual prediction.""" + prediction = self.model(tensor).cpu().detach().numpy() + return pd.DataFrame(prediction, columns=["predictions"]) + def predict( self, data: DataBatchType, @@ -63,6 +101,7 @@ def predict( Union[List[str], List[List[str]], List[int], List[List[int]]] ] = None, dtype: Optional[torch.dtype] = None, + unsqueeze: bool = True, ) -> DataBatchType: """Run inference on data batch. @@ -74,12 +113,19 @@ def predict( array. feature_columns: The names or indices of the columns in the data to use as features to predict on. If this arg is a - list of lists, then the data batch will be converted into a + list of lists or a dict of string-list pairs, then the + data batch will be converted into multiple tensors which are then concatenated before feeding into the model. This is useful for multi-input models. If None, then use all columns in ``data``. - dtype: The torch dtype to use when creating the torch tensor. - If set to None, then automatically infer the dtype. + dtype: The dtypes to use for the tensors. This should match the + format of ``feature_columns``, or be a single dtype, in which + case it will be applied to all tensors. + If None, then automatically infer the dtype. + unsqueeze: If set to True, the feature tensors + will be unsqueezed (reshaped to (N, 1)) before being concatenated into + the final features tensor. Otherwise, they will be left as is, that is + (N, ). Defaults to True. Examples: @@ -119,6 +165,8 @@ def predict( Returns: DataBatchType: Prediction result. """ + self.model.eval() + if self.preprocessor: data = self.preprocessor.transform_batch(data) @@ -126,10 +174,7 @@ def predict( # If numpy array, then convert to pandas dataframe. data = pd.DataFrame(data) - # TODO(amog): Add `_convert_numpy_to_torch_tensor to use based on input type.
- # Reduce conversion cost if input is in Numpy - tensor = convert_pandas_to_torch_tensor( - data, columns=feature_columns, column_dtypes=dtype + tensor = self._convert_to_tensor( + data, feature_columns=feature_columns, dtypes=dtype, unsqueeze=unsqueeze ) - prediction = self.model(tensor).cpu().detach().numpy() - return pd.DataFrame(prediction, columns=["predictions"]) + return self._predict(tensor) diff --git a/python/ray/ml/tests/_huggingface_data.py b/python/ray/ml/tests/_huggingface_data.py new file mode 100644 index 000000000000..670cc7102e14 --- /dev/null +++ b/python/ray/ml/tests/_huggingface_data.py @@ -0,0 +1,7 @@ +train_data = """ +{"input_ids":{"0":[238,8576,9441,2987,238,252,4657,74,4762,826,8576,428,466,609,6881,412,204,9441,311,2746,466,10816,168,99,150,192,112,14328,3983,112,4446,94,18288,4446,193,3983,98,3983,22171,95,19,201,6374,209,8576,218,198,3455,1972,428,310,201,5099,3242,227,281,8576,9441,2987,2553,1759,201,301,196,13996,1496,277,2330,1464,674,1898,307,742,3541,225,7514,14,54,719,274,198,4777,15522,209,19895,221,1341,1633,221,1759,201,322,301,198,1368,674,221,198,8576,843,209,2468,1795,223,198,1049,9595,218,13996,225,1563,277,582,6493,281,457,14371,201,198,1422,3373,7452,227,198,455,674,225,4687],"1":[198,239,21976,239,201,196,21657,1680,3773,5591,198,4196,218,4679,427,661,198,3518,1288,220,1051,516,889,3947,1922,2500,225,390,2065,744,872,198,7592,3773,239,1975,251,208,89,22351,239,209,252,261,674,959,1921,221,1462,201,7600,547,196,1178,4753,218,198,630,3591,263,8576,9441,1180,209,1831,322,7568,198,3621,2240,218,198,843,201,322,471,9575,5291,16591,967,201,781,281,1815,198,674,604,10344,1252,274,843,664,3147,320,209,13290,8751,8124,2528,6023,74,235,225,7445,10040,17384,241,11487,8950,857,1835,340,1382,22582,201,1008,296,8576,9441,1180,2436,21134,5337,19463,5161,209,240,1178],"2":[927,218,3776,8650,198,3355,209,261,674,268,83,2511,3472,258,8288,307,1010,268,78,209,252,530,922,296,3096,4354,221,1759,201,225,258,2630,307,857,2746,225,2761,2234,209,929,1634,201,322,1217,13362,4798,201,1008,296,326,5585,4615,221,1389,218,295,544,209,530,258,471,6459,587,8630,225,326,1446,1464,10260,843,209,5023,227,1669,4354,218,8576,9441,1180,201,8576,9441,2987,258,410,18878,201,445,196,6418,8895,20885,296,198,674,268,83,5585,4615,258,1055,221,2474,209,7514,14,54,719,602,1063,227,198,7072,296,198,1921,218,8576,466,8111,447,5474,274,198,4777,493,209,252],"3":[238,238,15114,238,238,252,662,296,1382,7784,9345,9441,1374,201,8576,9441,2987,301,196,13996,1496,277,2330,674,756,1997,1581,1694,218,196,1680,3773,225,1581,540,221,6455,872,5000,1754,209,8410,390,2848,718,5926,1184,277,1093,13659,296,10102,1136,13018,201,296,2081,6041,8040,718,7674,6917,14877,6632,225,8040,718,434,15744,5263,3973,209,261,1538,15606,718,196,843,218,11382,6455,201,5546,13230,5382,281,12903,295,684,293,15086,518,6145,718,225,1486,13196,281,543,390,13230,5382,209,261,2262,227,1140,1422,3371,263,198,7464,11718,8326,263,326,3854,1538,268,83,4156,466,646,496,6938,301,3941,201],"4":[198,548,301,15410,555,227,198,1538,209,13067,6455,201,198,1538,2081,1231,221,196,1306,201,756,2726,684,293,7376,968,225,1136,5168,6395,209,24281,198,906,1422,6455,390,1136,277,2737,747,6455,13516,227,1579,5860,1718,209,929,198,674,268,83,7131,201,3079,3411,390,13230,5382,201,766,218,642,1763,196,4014,6669,702,1413,824,221,198,1231,218,198,674,209,1415,390,471,2610,1076,1388,2377,3130,227,198,674,268,83,506,906,3641,828,201,1360,543,1581,196,1365,2978,1496,209,252,261,674,268,83,2094,1232,201,198,1028,73,52,58,1232,201,301,3493,547,3660,340,7784,93
45,9441,209],"5":[1303,6455,201,1997,6322,1140,3773,1919,196,1362,277,1004,8550,218,198,11332,7464,466,2109,196,1136,301,3941,201,198,1538,5757,198,1136,1118,198,11332,221,1368,277,1432,209,240,1136,684,669,733,2109,503,277,1407,201,445,2081,684,293,4929,5291,6398,319,198,13876,218,548,2081,268,6398,209,5056,1136,499,196,1606,225,3924,218,3672,3069,307,444,11840,315,15041,209,3264,227,2849,2081,684,293,3423,227,196,1314,4490,209,1303,6493,201,2081,1170,798,573,1358,3851,16048,227,642,201,781,281,444,3668,2423,311,282,48,310,6506,1669,371,812,11285,573,307,5000,3704,209,5056,1136,499,2737,239],"6":[8303,2453,83,239,201,8436,5025,227,1140,1136,209,1065,390,4695,587,239,12021,8303,2453,239,201,400,390,3295,365,8436,295,3052,434,222,8830,11270,7936,9067,380,307,198,1422,225,684,2292,1546,371,1017,21071,196,1136,201,225,239,3455,8303,2453,83,239,201,400,390,8498,1895,198,674,225,3365,3326,1958,675,227,196,1136,209,2000,2627,3455,8303,2453,83,201,1140,1136,499,196,5025,239,8119,236,614,239,201,196,14628,277,1356,11949,6230,295,684,293,788,227,13970,225,3261,1579,8436,209,16500,471,490,7447,1531,3004,295,3326,642,8535,1958,11069,263,198,11332,466,8625,684,1980,365,239,4495],"7":[3683,239,225,2341,1118,198,11332,1582,897,234,1236,342,11840,6957,16326,201,198,1136,789,9878,684,9627,587,468,239,8576,3408,239,225,1638,896,23892,201,732,2293,5336,684,4849,5291,5000,2726,296,468,2287,8263,209,252,13441,1355,390,4695,587,1160,6555,466,23804,201,808,722,84,228,19466,201,12785,201,7096,320,225,21755,18607,853,209,13441,19466,684,12454,6555,307,7048,444,3423,8263,209,11070,223,1249,1990,410,5993,11950,198,285,1201,4825,732,221,196,1382,1249,209,1585,2495,221,2094,201,3335,2423,390,3495,227,198,5860,201,400,390,6908,587,1160,1579,13376,6136,307,198,2068,5860,201,196,3135],"8":[1317,223,340,989,1374,268,3588,218,10078,5751,227,1579,3773,4325,209,252,238,238,6851,238,238,252,261,674,3544,1417,661,198,3518,1288,220,1051,209,4679,406,2205,13619,3775,18,201,471,876,281,239,261,21976,239,201,390,196,21657,1680,3773,3808,218,18490,201,4899,541,199,690,201,225,1680,555,16914,2821,1563,2850,390,1843,620,340,198,3517,225,817,200,4614,3242,227,307,3473,209,10730,665,307,198,4679,406,1680,227,889,198,659,8110,6455,295,198,22724,2205,225,18991,1170,410,804,201,543,390,15994,578,227,198,5214,201,20891,307,444,22232,79,201,18766,4209,1531,21157,201,3457,239],"9":[14996,789,2989,209,239,261,697,906,2081,390,1524,14,23,8625,2408,2513,201,326,3006,4417,16271,539,7184,218,15101,516,9461,227,24374,1920,440,8607,1524,14,17,2293,5336,201,196,2476,5447,67,3294,2287,5253,16021,516,17113,14241,872,198,8576,516,3443,468,1398,440,225,1524,14,17,19,297,793,65,20714,215,201,196,11573,536,197,88,204,2246,3063,516,301,4796,2616,5744,196,20713,218,198,8576,209,9332,296,444,5920,5860,1718,201,1021,697,390,10455,227,2172,872,196,12040,7592,3773,876,281,1975,251,379,22351,201,6355,218,3223,5447,67,3294,3478,209,252,662,198,21976,4614,804,410,2523],"10":[201,198,3578,5080,1193,675,218,198,4679,406,2205,21878,198,3301,218,21337,2903,14454,221,1688,227,5355,642,263,6455,295,602,7936,1445,4679,427,7843,3132,221,198,1060,209,1831,319,1696,526,2125,227,444,4517,201,781,281,196,1991,13719,587,7592,3833,201,548,4953,3522,2931,1718,218,198,3775,18,207,1428,16464,209,1539,781,2541,201,315,283,2809,201,4228,749,554,16361,295,331,3669,675,342,1753,225,770,2220,587,198,10100,218,1975,251,379,22351,201,6386,227,198,9541,218,5447,67,3294,5668,3412,307,444,5164,201,8954,971,209,775,198,1049,582,201,2377,1347,4679,406,2205,3683,2341,227],"11":[1843,626,198,21
976,221,1688,227,6259,444,1059,7969,209,282,15902,307,857,13855,225,9782,201,225,4439,296,198,3752,218,196,836,3547,1347,444,10100,201,198,3775,18,207,11695,1206,2341,227,3292,3230,10852,732,319,198,1049,582,2172,227,1546,198,4679,406,1060,4585,209,795,4471,1040,198,21976,268,83,14643,4417,201,3790,5491,22737,201,516,398,504,4755,699,2010,5125,201,301,15218,227,198,2867,1047,218,12106,71,337,90,221,1688,227,1474,2525,377,200,199,761,198,4239,89,3478,225,20416,198,1563,836,3547,201,198,4679,406,1796,295,398,7184,8625,218,3430,599,209,252,2304,264,1309,227],"12":[1021,1988,201,225,7259,1309,227,198,1214,5662,221,23734,4679,427,2433,320,2375,198,726,218,198,1060,296,198,4444,201,198,21976,390,3946,196,7944,1676,281,196,5860,221,198,4679,406,2205,2472,702,5228,281,326,13483,21340,2033,209,795,301,1311,277,3364,201,1650,201,281,1149,19033,15786,268,83,3441,201,8954,971,225,1975,251,379,22351,2341,227,1980,365,326,4479,7784,2920,2157,8263,1347,198,4444,201,4755,3947,307,444,5698,17446,209,9663,198,1117,218,19033,15786,371,198,6241,227,8223,3230,221,198,1060,296,4679,427,201,322,301,8954,971,268,83,1244,677,1940,6368,221,5290,196,664,5447],"13":[67,3294,4196,209,662,326,4579,4679,406,2033,896,5021,198,4444,1318,1149,198,506,8947,268,12199,277,1739,602,9684,14048,444,664,11418,4206,201,8625,8420,227,2109,695,1445,342,5860,198,21976,201,9271,22737,227,1343,1920,225,610,699,342,1434,281,2196,277,221,277,2429,209,7243,15215,17157,227,7235,548,702,3230,201,198,3775,18,207,21955,8954,971,225,1503,24815,198,7784,2920,8263,209,5056,2541,862,6084,444,4142,6294,221,1688,227,4555,444,4658,326,361,209,252,238,238,5823,238,238,252,19611,630,274,8576,9441,2987,959,550,1921,2393,263,8576,9441,1180,221,989,1462,201,296,1971,1921,3053],"14":[4691,550,526,209,261,2436,218,8576,9441,1180,201,21134,5337,19463,5161,201,1835,227,295,1496,274,8576,9441,2987,209,5823,630,1260,3032,496,544,209,929,198,1634,218,8576,9441,1180,201,198,3154,1260,196,1976,319,857,198,1737,3092,274,198,674,225,1241,543,2512,227,804,1533,274,198,843,209,4064,457,8472,201,8576,9441,2987,258,1898,274,4777,15522,466,526,258,1309,227,198,927,12182,227,1308,415,198,16590,2054,274,8576,9441,1180,201,225,543,398,410,3029,578,296,198,239,12252,239,3213,295,602,13884,196,664,6494,274,198,4777,428,209,11929,1006,221,326,2692,201,322,258,1821,295],"15":[198,1921,927,1535,8576,9441,2987,227,293,198,843,268,455,4099,7056,466,732,8576,9441,1180,398,3034,196,1178,3126,218,5429,225,12738,661,1921,1309,227,198,5770,2341,201,198,1368,674,1706,642,196,6241,227,7127,2451,198,1386,2773,218,8576,9441,1180,1309,227,812,263,198,1049,5770,209,357,1513,227,742,3541,3154,340,198,1382,1374,201,1921,630,258,471,8650,307,7514,14,54,719,14,261,1446,13395,258,1617,15415,9455,558,9884,706,69,201,732,198,3355,258,1617,307,11856,1652,9455,12645,74,5435,201,11490,20317,12329,10190,201,384,424,11300,21187,10198,201,742,14619,9765,368,5161,225,10830,288]},"attention_mask":{"0":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"1":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"2":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1
,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"3":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"4":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"5":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"6":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"7":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"8":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"9":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"10":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"11":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"12":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"13":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"14":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"15":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,
1,1,1,1,1,1,1,1,1,1,1]},"labels":{"0":[238,8576,9441,2987,238,252,4657,74,4762,826,8576,428,466,609,6881,412,204,9441,311,2746,466,10816,168,99,150,192,112,14328,3983,112,4446,94,18288,4446,193,3983,98,3983,22171,95,19,201,6374,209,8576,218,198,3455,1972,428,310,201,5099,3242,227,281,8576,9441,2987,2553,1759,201,301,196,13996,1496,277,2330,1464,674,1898,307,742,3541,225,7514,14,54,719,274,198,4777,15522,209,19895,221,1341,1633,221,1759,201,322,301,198,1368,674,221,198,8576,843,209,2468,1795,223,198,1049,9595,218,13996,225,1563,277,582,6493,281,457,14371,201,198,1422,3373,7452,227,198,455,674,225,4687],"1":[198,239,21976,239,201,196,21657,1680,3773,5591,198,4196,218,4679,427,661,198,3518,1288,220,1051,516,889,3947,1922,2500,225,390,2065,744,872,198,7592,3773,239,1975,251,208,89,22351,239,209,252,261,674,959,1921,221,1462,201,7600,547,196,1178,4753,218,198,630,3591,263,8576,9441,1180,209,1831,322,7568,198,3621,2240,218,198,843,201,322,471,9575,5291,16591,967,201,781,281,1815,198,674,604,10344,1252,274,843,664,3147,320,209,13290,8751,8124,2528,6023,74,235,225,7445,10040,17384,241,11487,8950,857,1835,340,1382,22582,201,1008,296,8576,9441,1180,2436,21134,5337,19463,5161,209,240,1178],"2":[927,218,3776,8650,198,3355,209,261,674,268,83,2511,3472,258,8288,307,1010,268,78,209,252,530,922,296,3096,4354,221,1759,201,225,258,2630,307,857,2746,225,2761,2234,209,929,1634,201,322,1217,13362,4798,201,1008,296,326,5585,4615,221,1389,218,295,544,209,530,258,471,6459,587,8630,225,326,1446,1464,10260,843,209,5023,227,1669,4354,218,8576,9441,1180,201,8576,9441,2987,258,410,18878,201,445,196,6418,8895,20885,296,198,674,268,83,5585,4615,258,1055,221,2474,209,7514,14,54,719,602,1063,227,198,7072,296,198,1921,218,8576,466,8111,447,5474,274,198,4777,493,209,252],"3":[238,238,15114,238,238,252,662,296,1382,7784,9345,9441,1374,201,8576,9441,2987,301,196,13996,1496,277,2330,674,756,1997,1581,1694,218,196,1680,3773,225,1581,540,221,6455,872,5000,1754,209,8410,390,2848,718,5926,1184,277,1093,13659,296,10102,1136,13018,201,296,2081,6041,8040,718,7674,6917,14877,6632,225,8040,718,434,15744,5263,3973,209,261,1538,15606,718,196,843,218,11382,6455,201,5546,13230,5382,281,12903,295,684,293,15086,518,6145,718,225,1486,13196,281,543,390,13230,5382,209,261,2262,227,1140,1422,3371,263,198,7464,11718,8326,263,326,3854,1538,268,83,4156,466,646,496,6938,301,3941,201],"4":[198,548,301,15410,555,227,198,1538,209,13067,6455,201,198,1538,2081,1231,221,196,1306,201,756,2726,684,293,7376,968,225,1136,5168,6395,209,24281,198,906,1422,6455,390,1136,277,2737,747,6455,13516,227,1579,5860,1718,209,929,198,674,268,83,7131,201,3079,3411,390,13230,5382,201,766,218,642,1763,196,4014,6669,702,1413,824,221,198,1231,218,198,674,209,1415,390,471,2610,1076,1388,2377,3130,227,198,674,268,83,506,906,3641,828,201,1360,543,1581,196,1365,2978,1496,209,252,261,674,268,83,2094,1232,201,198,1028,73,52,58,1232,201,301,3493,547,3660,340,7784,9345,9441,209],"5":[1303,6455,201,1997,6322,1140,3773,1919,196,1362,277,1004,8550,218,198,11332,7464,466,2109,196,1136,301,3941,201,198,1538,5757,198,1136,1118,198,11332,221,1368,277,1432,209,240,1136,684,669,733,2109,503,277,1407,201,445,2081,684,293,4929,5291,6398,319,198,13876,218,548,2081,268,6398,209,5056,1136,499,196,1606,225,3924,218,3672,3069,307,444,11840,315,15041,209,3264,227,2849,2081,684,293,3423,227,196,1314,4490,209,1303,6493,201,2081,1170,798,573,1358,3851,16048,227,642,201,781,281,444,3668,2423,311,282,48,310,6506,1669,371,812,11285,573,307,5000,3704,209,5056,1136,499,2737,239],"6":[8303,2453,83,239,201,8436,
5025,227,1140,1136,209,1065,390,4695,587,239,12021,8303,2453,239,201,400,390,3295,365,8436,295,3052,434,222,8830,11270,7936,9067,380,307,198,1422,225,684,2292,1546,371,1017,21071,196,1136,201,225,239,3455,8303,2453,83,239,201,400,390,8498,1895,198,674,225,3365,3326,1958,675,227,196,1136,209,2000,2627,3455,8303,2453,83,201,1140,1136,499,196,5025,239,8119,236,614,239,201,196,14628,277,1356,11949,6230,295,684,293,788,227,13970,225,3261,1579,8436,209,16500,471,490,7447,1531,3004,295,3326,642,8535,1958,11069,263,198,11332,466,8625,684,1980,365,239,4495],"7":[3683,239,225,2341,1118,198,11332,1582,897,234,1236,342,11840,6957,16326,201,198,1136,789,9878,684,9627,587,468,239,8576,3408,239,225,1638,896,23892,201,732,2293,5336,684,4849,5291,5000,2726,296,468,2287,8263,209,252,13441,1355,390,4695,587,1160,6555,466,23804,201,808,722,84,228,19466,201,12785,201,7096,320,225,21755,18607,853,209,13441,19466,684,12454,6555,307,7048,444,3423,8263,209,11070,223,1249,1990,410,5993,11950,198,285,1201,4825,732,221,196,1382,1249,209,1585,2495,221,2094,201,3335,2423,390,3495,227,198,5860,201,400,390,6908,587,1160,1579,13376,6136,307,198,2068,5860,201,196,3135],"8":[1317,223,340,989,1374,268,3588,218,10078,5751,227,1579,3773,4325,209,252,238,238,6851,238,238,252,261,674,3544,1417,661,198,3518,1288,220,1051,209,4679,406,2205,13619,3775,18,201,471,876,281,239,261,21976,239,201,390,196,21657,1680,3773,3808,218,18490,201,4899,541,199,690,201,225,1680,555,16914,2821,1563,2850,390,1843,620,340,198,3517,225,817,200,4614,3242,227,307,3473,209,10730,665,307,198,4679,406,1680,227,889,198,659,8110,6455,295,198,22724,2205,225,18991,1170,410,804,201,543,390,15994,578,227,198,5214,201,20891,307,444,22232,79,201,18766,4209,1531,21157,201,3457,239],"9":[14996,789,2989,209,239,261,697,906,2081,390,1524,14,23,8625,2408,2513,201,326,3006,4417,16271,539,7184,218,15101,516,9461,227,24374,1920,440,8607,1524,14,17,2293,5336,201,196,2476,5447,67,3294,2287,5253,16021,516,17113,14241,872,198,8576,516,3443,468,1398,440,225,1524,14,17,19,297,793,65,20714,215,201,196,11573,536,197,88,204,2246,3063,516,301,4796,2616,5744,196,20713,218,198,8576,209,9332,296,444,5920,5860,1718,201,1021,697,390,10455,227,2172,872,196,12040,7592,3773,876,281,1975,251,379,22351,201,6355,218,3223,5447,67,3294,3478,209,252,662,198,21976,4614,804,410,2523],"10":[201,198,3578,5080,1193,675,218,198,4679,406,2205,21878,198,3301,218,21337,2903,14454,221,1688,227,5355,642,263,6455,295,602,7936,1445,4679,427,7843,3132,221,198,1060,209,1831,319,1696,526,2125,227,444,4517,201,781,281,196,1991,13719,587,7592,3833,201,548,4953,3522,2931,1718,218,198,3775,18,207,1428,16464,209,1539,781,2541,201,315,283,2809,201,4228,749,554,16361,295,331,3669,675,342,1753,225,770,2220,587,198,10100,218,1975,251,379,22351,201,6386,227,198,9541,218,5447,67,3294,5668,3412,307,444,5164,201,8954,971,209,775,198,1049,582,201,2377,1347,4679,406,2205,3683,2341,227],"11":[1843,626,198,21976,221,1688,227,6259,444,1059,7969,209,282,15902,307,857,13855,225,9782,201,225,4439,296,198,3752,218,196,836,3547,1347,444,10100,201,198,3775,18,207,11695,1206,2341,227,3292,3230,10852,732,319,198,1049,582,2172,227,1546,198,4679,406,1060,4585,209,795,4471,1040,198,21976,268,83,14643,4417,201,3790,5491,22737,201,516,398,504,4755,699,2010,5125,201,301,15218,227,198,2867,1047,218,12106,71,337,90,221,1688,227,1474,2525,377,200,199,761,198,4239,89,3478,225,20416,198,1563,836,3547,201,198,4679,406,1796,295,398,7184,8625,218,3430,599,209,252,2304,264,1309,227],"12":[1021,1988,201,225,7259,1309,227,198,1214,5662,221,23734,4679,42
7,2433,320,2375,198,726,218,198,1060,296,198,4444,201,198,21976,390,3946,196,7944,1676,281,196,5860,221,198,4679,406,2205,2472,702,5228,281,326,13483,21340,2033,209,795,301,1311,277,3364,201,1650,201,281,1149,19033,15786,268,83,3441,201,8954,971,225,1975,251,379,22351,2341,227,1980,365,326,4479,7784,2920,2157,8263,1347,198,4444,201,4755,3947,307,444,5698,17446,209,9663,198,1117,218,19033,15786,371,198,6241,227,8223,3230,221,198,1060,296,4679,427,201,322,301,8954,971,268,83,1244,677,1940,6368,221,5290,196,664,5447],"13":[67,3294,4196,209,662,326,4579,4679,406,2033,896,5021,198,4444,1318,1149,198,506,8947,268,12199,277,1739,602,9684,14048,444,664,11418,4206,201,8625,8420,227,2109,695,1445,342,5860,198,21976,201,9271,22737,227,1343,1920,225,610,699,342,1434,281,2196,277,221,277,2429,209,7243,15215,17157,227,7235,548,702,3230,201,198,3775,18,207,21955,8954,971,225,1503,24815,198,7784,2920,8263,209,5056,2541,862,6084,444,4142,6294,221,1688,227,4555,444,4658,326,361,209,252,238,238,5823,238,238,252,19611,630,274,8576,9441,2987,959,550,1921,2393,263,8576,9441,1180,221,989,1462,201,296,1971,1921,3053],"14":[4691,550,526,209,261,2436,218,8576,9441,1180,201,21134,5337,19463,5161,201,1835,227,295,1496,274,8576,9441,2987,209,5823,630,1260,3032,496,544,209,929,198,1634,218,8576,9441,1180,201,198,3154,1260,196,1976,319,857,198,1737,3092,274,198,674,225,1241,543,2512,227,804,1533,274,198,843,209,4064,457,8472,201,8576,9441,2987,258,1898,274,4777,15522,466,526,258,1309,227,198,927,12182,227,1308,415,198,16590,2054,274,8576,9441,1180,201,225,543,398,410,3029,578,296,198,239,12252,239,3213,295,602,13884,196,664,6494,274,198,4777,428,209,11929,1006,221,326,2692,201,322,258,1821,295],"15":[198,1921,927,1535,8576,9441,2987,227,293,198,843,268,455,4099,7056,466,732,8576,9441,1180,398,3034,196,1178,3126,218,5429,225,12738,661,1921,1309,227,198,5770,2341,201,198,1368,674,1706,642,196,6241,227,7127,2451,198,1386,2773,218,8576,9441,1180,1309,227,812,263,198,1049,5770,209,357,1513,227,742,3541,3154,340,198,1382,1374,201,1921,630,258,471,8650,307,7514,14,54,719,14,261,1446,13395,258,1617,15415,9455,558,9884,706,69,201,732,198,3355,258,1617,307,11856,1652,9455,12645,74,5435,201,11490,20317,12329,10190,201,384,424,11300,21187,10198,201,742,14619,9765,368,5161,225,10830,288]}} +""" + +validation_data = """ 
+{"input_ids":{"0":[238,19210,12331,2208,283,238,252,19210,12331,2208,283,201,876,281,198,2364,17152,371,1512,17152,201,301,196,1293,218,472,655,204,17152,340,198,2225,2477,4575,201,7662,4116,225,2773,218,198,1885,4116,209,530,301,5952,3130,227,198,1013,17152,201,282,14,196,760,12898,209,530,977,2156,227,196,2591,218,2838,3171,311,1678,221,310,225,196,2223,218,605,21695,311,1086,4336,310,201,225,10043,196,20778,3289,218,23714,209,357,1151,201,198,7162,9879,390,4071,201,669,3179,239,17152,1994,239,263,12897,209,259,761,6395,221,198,3271,201,6453,7144,400,390,3493,307,198,5138,274,578,227,196],"1":[544,806,6066,2753,587,23664,851,229,17469,209,19210,12331,2208,283,301,196,3879,928,21008,204,2970,201,225,301,3991,6637,1919,17152,224,1637,201,3223,1118,198,1038,11303,209,252,238,238,10150,238,238,252,19210,12331,2208,283,301,196,1178,18183,571,220,201,296,196,1825,2591,578,227,2838,10892,311,1678,221,310,225,19186,578,227,502,459,605,21695,311,920,459,1086,4336,310,201,1360,198,7162,9879,6637,221,17152,224,1637,390,3076,1625,459,3075,3171,311,670,459,916,221,310,908,225,11698,796,511,629,459,289,511,289,4890,311,245,511,502,459,493,511,670,4336,310,209,4064,548,18183],"2":[571,738,201,7162,9879,490,196,2804,377,316,559,8605,400,543,2686,14864,221,1688,227,2156,201,221,196,2169,994,5080,68,1162,215,311,24820,1236,310,209,795,977,2100,951,1696,196,544,274,2246,7162,9879,201,445,24202,227,2109,1621,245,459,289,818,274,3097,5503,209,252,261,455,3289,218,224,350,73,343,11051,301,4579,296,196,1178,201,20234,4371,3688,3289,218,23714,209,261,3097,496,301,198,239,1643,17959,239,201,225,499,12466,24880,5571,788,274,1643,5003,6490,440,198,548,301,198,239,2439,280,239,201,400,499,9629,6639,8864,201,225,301,788,274,6205,371,194,10866,198,6490],"3":[209,4891,957,201,198,1344,472,655,301,198,1643,17959,201,225,198,1557,301,198,2439,280,209,252,261,377,316,559,8605,301,2370,4071,2660,201,296,9915,295,767,1130,298,201,225,4578,4052,209,261,1994,6334,3720,296,7162,9879,669,3440,550,12897,209,795,6395,1080,201,221,1151,201,198,1994,24422,4624,1215,6233,197,301,3460,227,196,4279,3794,201,445,198,3794,301,4752,578,307,198,8126,218,12897,201,11701,198,1994,24422,209,252,261,8445,7339,218,282,14,12331,2208,283,301,198,1013,17152,201,19210,196,760,12898,209,261,506,1293,390,1365,1553,201,225,684,293,6982,982,778,1053,201],"4":[1360,15608,2521,390,9718,227,2100,221,198,4520,1211,444,8564,804,410,14358,209,261,506,1293,684,293,8793,307,196,845,218,6498,466,252,261,494,246,5398,218,282,14,196,760,12898,10043,496,371,604,20338,263,198,24534,201,400,390,10581,221,282,14,12331,2208,283,209,252,261,20338,263,198,23714,218,282,14,196,760,12898,390,1994,371,1994,277,22181,201,732,1413,218,282,14,12331,2208,283,390,1969,371,1969,277,22181,209,252,261,24534,218,198,472,655,218,282,14,196,760,12898,301,12960,371,1994,201,732,295,218,282,14,12331,2208,283,301,942,5676,1969,371,1365,8202,1994,209],"5":[252,238,238,5971,7308,238,238,252,14035,282,14,12331,2208,283,3484,3956,13559,646,543,490,8498,227,196,1044,422,571,2591,218,3105,459,5626,22608,311,428,511,245,459,428,511,428,221,310,201,6844,5020,10719,319,196,4370,3199,2892,209,259,761,4183,6395,221,3271,708,196,5527,24820,744,2476,201,2821,4093,301,4947,6144,201,225,196,2804,277,24110,2956,209,261,2476,9651,198,7144,274,578,227,837,2200,201,8326,263,198,5010,201,6386,227,468,4405,343,11051,209,17172,7600,7144,390,1014,227,293,239,20148,1635,239,225,684,293,824,1895,198,544,209,252,261,7144,15638,319,1948,201,225],"6":[198,17469,9408,227,198,1371,3239,756,543,11642,296,198,7153,2
0898,201,577,2464,263,22762,8222,858,851,209,795,2570,11533,697,24820,745,225,15964,274,916,459,2380,1651,209,929,198,1368,24820,84,201,198,16695,3544,263,196,740,7727,227,198,4047,201,225,3313,745,196,219,2121,229,13643,209,261,16696,390,6678,1809,221,198,4520,201,225,390,7808,876,201,1360,543,390,876,227,293,5992,218,15590,5536,2515,6846,209,530,301,2729,295,669,245,768,6820,221,1621,464,562,802,14272,227,198,219,2121,229,5427,209,1438,543,3484,196,1044,422,571,2591,218,916,1999,311,796,511,5931,221,310],"7":[201,198,16696,3489,444,2515,6846,225,1046,444,4047,4658,209,252,238,238,14720,238,238,252,19210,12331,2208,283,301,824,1896,198,1128,277,2225,2477,4575,340,2518,7984,227,198,14780,225,22054,201,410,867,198,18294,4116,209,530,301,471,1474,221,659,218,198,7662,4116,201,669,6157,340,198,3161,1657,218,4968,563,201,225,1008,669,198,1128,277,1639,2076,218,198,1885,4116,209,261,2518,1942,5892,390,824,221,198,10768,214,74,1907,236,1162,70,74,412,213,225,18700,17840,65,201,3950,198,21427,9743,234,209,252,261,1293,684,293,4695,587,784,1851,269,1320,4344,5892,201,496,6098,1866],"8":[201,225,697,400,490,9068,1166,1309,227,1119,5501,1866,15321,201,4692,1309,227,8552,227,198,1492,5039,209,261,455,218,1021,301,198,1866,218,7162,9879,340,2518,7984,201,400,490,504,3242,227,281,198,239,16501,277,6271,17152,239,209,261,5892,221,198,7662,4116,390,4344,340,1413,221,198,2477,4575,209,261,1244,4344,1866,301,824,221,198,7533,466,12828,340,198,339,419,320,4651,352,69,360,4344,340,1413,5464,221,198,1182,4116,371,1545,7484,209,252,24135,83,490,504,746,227,14275,282,14,12331,2208,283,227,752,2952,201,4583,548,2364,1293,781,281,198,15494,21089,201,24872,224,330],"9":[260,283,209,5507,11329,225,5989,201,496,1079,17152,17469,360,1055,340,6066,2264,346,221,5222,204,197,201,445,198,1293,992,410,1638,1860,817,209,252,238,238,19976,238,238,252,11312,282,14,12331,2208,283,2253,263,198,10077,22975,319,21762,218,796,459,4178,3224,311,796,459,4468,18,1841,310,201,1360,410,7736,15408,702,1834,230,311,5722,1841,310,209,1065,5463,2804,8138,7075,201,781,281,10578,371,2804,11832,201,225,2253,221,12782,371,942,86,1656,201,16062,319,1948,227,7710,209,252,261,7546,218,282,14,12331,2208,283,3223,5886,218,548,219,2121,229,18499,209,1850,1766,6063,1466,201],"10":[230,1894,13799,83,201,3304,159,885,776,201,1322,11207,225,4291,18680,16328,21077,209,252,261,697,472,655,204,17152,1293,19210,12331,2208,283,201,282,14,196,760,12898,225,1853,575,635,83,5746,711,71,10491,390,12202,227,198,697,876,1293,218,198,4926,1091,930,327,247,2236,73,2401,4403,440,198,1293,263,282,14,12331,2208,283,499,410,504,1246,209,252,19210,12331,2208,283,301,22081,227,198,4516,267,1867,4219,7024,201,2035,307,198,14477,1026,5611,375,13489,283,9769,271,738,209,1647,322,301,5431,824,221,1013,7162,9879,201,198,4516,499,669,504,1809,221,19676,282,14,12331,2208,283],"11":[201,756,3119,9025,218,198,8284,307,282,14,196,760,12898,990,410,293,6620,573,209,252,238,238,6870,11391,238,238,252,19210,12331,2208,283,301,13924,239,3879,928,21008,204,239,281,196,2970,246,3760,225,258,4478,221,239,261,10835,11207,239,196,18409,926,1545,4676,647,209,530,977,13762,309,1365,823,15210,225,977,293,1574,5739,201,16791,7010,201,684,2457,371,4192,68,665,209,3965,198,23714,225,198,22051,218,282,14,12331,2208,283,4607,239,13123,239,1969,12394,201,225,659,218,198,15891,218,198,2374,15444,679,210,1215,390,15494,209,261,17588,390,198,21202,523,955,225,198,239,4936],"12":[22774,239,311,21201,310,209,261,11588,218,282,14,12331,2208,283,301,578,227,697,1696,4014,702,295,218,282,14,196,760,12898,20
1,225,198,2364,1293,301,1535,227,490,196,3077,12693,418,209,252,317,706,9879,390,3223,214,736,1919,17152,224,1637,201,1360,2875,22218,616,296,11820,20305,371,2439,84,234,11207,3019,8900,221,22179,223,642,573,201,227,1313,642,227,293,6637,221,196,8946,371,307,1414,209,357,1465,201,493,562,3075,22,194,218,282,14,12331,2208,283,360,6637,1896,1399,225,1182,3054,201,218,400,428,562,4683,18,194,311,6600,1124,310,258,6637,221,198,1038,11303,311],"13":[867,198,7484,7362,310,209,261,8240,5360,2892,274,282,14,12331,2208,283,301,196,1044,422,571,2591,218,6286,1999,311,428,511,493,221,310,209,252,13773,257,21384,5135,274,282,14,12331,2208,283,390,699,1921,201,225,1743,7928,390,1490,1365,1669,209,252,238,238,8937,16097,1514,238,238,252,19210,12331,2208,283,258,455,1782,196,12072,22267,1221,307,5064,317,909,19649,221,198,8813,4615,218,342,5599,65,318,4881,1240,201,1573,221,5986,24,209,2881,1221,258,24872,12331,2208,283,201,1211,317,909,19649,268,3301,218,198,4013,24872,319,295,582,1331,610,1178,18183,571,738,209,252,282,14],"14":[12331,2208,283,301,198,3108,1293,218,198,4013,19210,17518,201,20690,201,281,4729,307,300,4340,3945,218,198,2392,4774,263,16455,1839,318,249,22701,1411,209,7666,227,295,4568,201,11527,12061,1080,198,1293,398,504,3242,227,307,951,1579,2850,201,867,5265,257,283,1261,4487,14166,523,2215,201,22956,225,19210,18381,215,282,14,3248,1775,277,18903,201,14300,201,225,471,1080,14415,17518,268,83,5278,218,198,4013,398,504,17228,1040,1994,22708,307,2489,297,633,66,302,201,19912,1061,3119,13525,967,218,3108,1293,311,274,19210,282,14,3248,1775,277,18903,201,14300,310,21468,274,19210,17518,201,20690,209,252],"15":[261,3108,7139,218,19210,12331,2208,283,258,196,10529,7769,3941,307,23601,559,3059,299,85,215,221,5629,209,530,1501,340,5197,2719,4800,396,110,318,920,2719,3434,396,110,338,201,996,681,246,1245,201,9704,311,4067,5549,371,1342,1957,4587,218,9901,787,6012,310,201,445,857,322,225,198,1346,222,395,14403,490,1211,504,1777,209,252,261,1512,1221,274,282,14,12331,2208,283,7676,307,198,9758,225,15124,447,8473,301,239,2364,17152,239,201,445,198,1293,301,471,3991,876,281,198,239,1512,17152,239,209,252,238,3180,5577,4593,238,252,1934,8314,12825,3180,5577,4593,201,17920,201,17920,37]},"attention_mask":{"0":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"1":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"2":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"3":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"4":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"5":[1,1,1,1
,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"6":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"7":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"8":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"9":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"10":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"11":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"12":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"13":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"14":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1],"15":[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1]},"labels":{"0":[238,19210,12331,2208,283,238,252,19210,12331,2208,283,201,876,281,198,2364,17152,371,1512,17152,201,301,196,1293,218,472,655,204,17152,340,198,2225,2477,4575,201,7662,4116,225,2773,218,198,1885,4116,209,530,301,5952,3130,227,198,1013,17152,201,282,14,196,760,12898,209,530,977,2156,227,196,2591,218,2838,3171,311,1678,221,310,225,196,2223,218,605,21695,311,1086,4336,310,201,225,10043,196,20778,3289,218,23714,209,357,1151,201,198,7162,9879,390,4071,201,669,3179,239,17152,1994,239,263,12897,209,259,761,6395,221,198,3271,201,6453,7144,400,390,3493,307,198,5138,274,578,227,196],"1":[544,806,6066,2753,587,23664,851,229,17469,209,19210,12331,2208,283,301,196,3879
,928,21008,204,2970,201,225,301,3991,6637,1919,17152,224,1637,201,3223,1118,198,1038,11303,209,252,238,238,10150,238,238,252,19210,12331,2208,283,301,196,1178,18183,571,220,201,296,196,1825,2591,578,227,2838,10892,311,1678,221,310,225,19186,578,227,502,459,605,21695,311,920,459,1086,4336,310,201,1360,198,7162,9879,6637,221,17152,224,1637,390,3076,1625,459,3075,3171,311,670,459,916,221,310,908,225,11698,796,511,629,459,289,511,289,4890,311,245,511,502,459,493,511,670,4336,310,209,4064,548,18183],"2":[571,738,201,7162,9879,490,196,2804,377,316,559,8605,400,543,2686,14864,221,1688,227,2156,201,221,196,2169,994,5080,68,1162,215,311,24820,1236,310,209,795,977,2100,951,1696,196,544,274,2246,7162,9879,201,445,24202,227,2109,1621,245,459,289,818,274,3097,5503,209,252,261,455,3289,218,224,350,73,343,11051,301,4579,296,196,1178,201,20234,4371,3688,3289,218,23714,209,261,3097,496,301,198,239,1643,17959,239,201,225,499,12466,24880,5571,788,274,1643,5003,6490,440,198,548,301,198,239,2439,280,239,201,400,499,9629,6639,8864,201,225,301,788,274,6205,371,194,10866,198,6490],"3":[209,4891,957,201,198,1344,472,655,301,198,1643,17959,201,225,198,1557,301,198,2439,280,209,252,261,377,316,559,8605,301,2370,4071,2660,201,296,9915,295,767,1130,298,201,225,4578,4052,209,261,1994,6334,3720,296,7162,9879,669,3440,550,12897,209,795,6395,1080,201,221,1151,201,198,1994,24422,4624,1215,6233,197,301,3460,227,196,4279,3794,201,445,198,3794,301,4752,578,307,198,8126,218,12897,201,11701,198,1994,24422,209,252,261,8445,7339,218,282,14,12331,2208,283,301,198,1013,17152,201,19210,196,760,12898,209,261,506,1293,390,1365,1553,201,225,684,293,6982,982,778,1053,201],"4":[1360,15608,2521,390,9718,227,2100,221,198,4520,1211,444,8564,804,410,14358,209,261,506,1293,684,293,8793,307,196,845,218,6498,466,252,261,494,246,5398,218,282,14,196,760,12898,10043,496,371,604,20338,263,198,24534,201,400,390,10581,221,282,14,12331,2208,283,209,252,261,20338,263,198,23714,218,282,14,196,760,12898,390,1994,371,1994,277,22181,201,732,1413,218,282,14,12331,2208,283,390,1969,371,1969,277,22181,209,252,261,24534,218,198,472,655,218,282,14,196,760,12898,301,12960,371,1994,201,732,295,218,282,14,12331,2208,283,301,942,5676,1969,371,1365,8202,1994,209],"5":[252,238,238,5971,7308,238,238,252,14035,282,14,12331,2208,283,3484,3956,13559,646,543,490,8498,227,196,1044,422,571,2591,218,3105,459,5626,22608,311,428,511,245,459,428,511,428,221,310,201,6844,5020,10719,319,196,4370,3199,2892,209,259,761,4183,6395,221,3271,708,196,5527,24820,744,2476,201,2821,4093,301,4947,6144,201,225,196,2804,277,24110,2956,209,261,2476,9651,198,7144,274,578,227,837,2200,201,8326,263,198,5010,201,6386,227,468,4405,343,11051,209,17172,7600,7144,390,1014,227,293,239,20148,1635,239,225,684,293,824,1895,198,544,209,252,261,7144,15638,319,1948,201,225],"6":[198,17469,9408,227,198,1371,3239,756,543,11642,296,198,7153,20898,201,577,2464,263,22762,8222,858,851,209,795,2570,11533,697,24820,745,225,15964,274,916,459,2380,1651,209,929,198,1368,24820,84,201,198,16695,3544,263,196,740,7727,227,198,4047,201,225,3313,745,196,219,2121,229,13643,209,261,16696,390,6678,1809,221,198,4520,201,225,390,7808,876,201,1360,543,390,876,227,293,5992,218,15590,5536,2515,6846,209,530,301,2729,295,669,245,768,6820,221,1621,464,562,802,14272,227,198,219,2121,229,5427,209,1438,543,3484,196,1044,422,571,2591,218,916,1999,311,796,511,5931,221,310],"7":[201,198,16696,3489,444,2515,6846,225,1046,444,4047,4658,209,252,238,238,14720,238,238,252,19210,12331,2208,283,301,824,1896,198,1128,277,2225,2477,4575,340,2518,7984,
227,198,14780,225,22054,201,410,867,198,18294,4116,209,530,301,471,1474,221,659,218,198,7662,4116,201,669,6157,340,198,3161,1657,218,4968,563,201,225,1008,669,198,1128,277,1639,2076,218,198,1885,4116,209,261,2518,1942,5892,390,824,221,198,10768,214,74,1907,236,1162,70,74,412,213,225,18700,17840,65,201,3950,198,21427,9743,234,209,252,261,1293,684,293,4695,587,784,1851,269,1320,4344,5892,201,496,6098,1866],"8":[201,225,697,400,490,9068,1166,1309,227,1119,5501,1866,15321,201,4692,1309,227,8552,227,198,1492,5039,209,261,455,218,1021,301,198,1866,218,7162,9879,340,2518,7984,201,400,490,504,3242,227,281,198,239,16501,277,6271,17152,239,209,261,5892,221,198,7662,4116,390,4344,340,1413,221,198,2477,4575,209,261,1244,4344,1866,301,824,221,198,7533,466,12828,340,198,339,419,320,4651,352,69,360,4344,340,1413,5464,221,198,1182,4116,371,1545,7484,209,252,24135,83,490,504,746,227,14275,282,14,12331,2208,283,227,752,2952,201,4583,548,2364,1293,781,281,198,15494,21089,201,24872,224,330],"9":[260,283,209,5507,11329,225,5989,201,496,1079,17152,17469,360,1055,340,6066,2264,346,221,5222,204,197,201,445,198,1293,992,410,1638,1860,817,209,252,238,238,19976,238,238,252,11312,282,14,12331,2208,283,2253,263,198,10077,22975,319,21762,218,796,459,4178,3224,311,796,459,4468,18,1841,310,201,1360,410,7736,15408,702,1834,230,311,5722,1841,310,209,1065,5463,2804,8138,7075,201,781,281,10578,371,2804,11832,201,225,2253,221,12782,371,942,86,1656,201,16062,319,1948,227,7710,209,252,261,7546,218,282,14,12331,2208,283,3223,5886,218,548,219,2121,229,18499,209,1850,1766,6063,1466,201],"10":[230,1894,13799,83,201,3304,159,885,776,201,1322,11207,225,4291,18680,16328,21077,209,252,261,697,472,655,204,17152,1293,19210,12331,2208,283,201,282,14,196,760,12898,225,1853,575,635,83,5746,711,71,10491,390,12202,227,198,697,876,1293,218,198,4926,1091,930,327,247,2236,73,2401,4403,440,198,1293,263,282,14,12331,2208,283,499,410,504,1246,209,252,19210,12331,2208,283,301,22081,227,198,4516,267,1867,4219,7024,201,2035,307,198,14477,1026,5611,375,13489,283,9769,271,738,209,1647,322,301,5431,824,221,1013,7162,9879,201,198,4516,499,669,504,1809,221,19676,282,14,12331,2208,283],"11":[201,756,3119,9025,218,198,8284,307,282,14,196,760,12898,990,410,293,6620,573,209,252,238,238,6870,11391,238,238,252,19210,12331,2208,283,301,13924,239,3879,928,21008,204,239,281,196,2970,246,3760,225,258,4478,221,239,261,10835,11207,239,196,18409,926,1545,4676,647,209,530,977,13762,309,1365,823,15210,225,977,293,1574,5739,201,16791,7010,201,684,2457,371,4192,68,665,209,3965,198,23714,225,198,22051,218,282,14,12331,2208,283,4607,239,13123,239,1969,12394,201,225,659,218,198,15891,218,198,2374,15444,679,210,1215,390,15494,209,261,17588,390,198,21202,523,955,225,198,239,4936],"12":[22774,239,311,21201,310,209,261,11588,218,282,14,12331,2208,283,301,578,227,697,1696,4014,702,295,218,282,14,196,760,12898,201,225,198,2364,1293,301,1535,227,490,196,3077,12693,418,209,252,317,706,9879,390,3223,214,736,1919,17152,224,1637,201,1360,2875,22218,616,296,11820,20305,371,2439,84,234,11207,3019,8900,221,22179,223,642,573,201,227,1313,642,227,293,6637,221,196,8946,371,307,1414,209,357,1465,201,493,562,3075,22,194,218,282,14,12331,2208,283,360,6637,1896,1399,225,1182,3054,201,218,400,428,562,4683,18,194,311,6600,1124,310,258,6637,221,198,1038,11303,311],"13":[867,198,7484,7362,310,209,261,8240,5360,2892,274,282,14,12331,2208,283,301,196,1044,422,571,2591,218,6286,1999,311,428,511,493,221,310,209,252,13773,257,21384,5135,274,282,14,12331,2208,283,390,699,1921,201,225,1743,7928,390,1490,1365
,1669,209,252,238,238,8937,16097,1514,238,238,252,19210,12331,2208,283,258,455,1782,196,12072,22267,1221,307,5064,317,909,19649,221,198,8813,4615,218,342,5599,65,318,4881,1240,201,1573,221,5986,24,209,2881,1221,258,24872,12331,2208,283,201,1211,317,909,19649,268,3301,218,198,4013,24872,319,295,582,1331,610,1178,18183,571,738,209,252,282,14],"14":[12331,2208,283,301,198,3108,1293,218,198,4013,19210,17518,201,20690,201,281,4729,307,300,4340,3945,218,198,2392,4774,263,16455,1839,318,249,22701,1411,209,7666,227,295,4568,201,11527,12061,1080,198,1293,398,504,3242,227,307,951,1579,2850,201,867,5265,257,283,1261,4487,14166,523,2215,201,22956,225,19210,18381,215,282,14,3248,1775,277,18903,201,14300,201,225,471,1080,14415,17518,268,83,5278,218,198,4013,398,504,17228,1040,1994,22708,307,2489,297,633,66,302,201,19912,1061,3119,13525,967,218,3108,1293,311,274,19210,282,14,3248,1775,277,18903,201,14300,310,21468,274,19210,17518,201,20690,209,252],"15":[261,3108,7139,218,19210,12331,2208,283,258,196,10529,7769,3941,307,23601,559,3059,299,85,215,221,5629,209,530,1501,340,5197,2719,4800,396,110,318,920,2719,3434,396,110,338,201,996,681,246,1245,201,9704,311,4067,5549,371,1342,1957,4587,218,9901,787,6012,310,201,445,857,322,225,198,1346,222,395,14403,490,1211,504,1777,209,252,261,1512,1221,274,282,14,12331,2208,283,7676,307,198,9758,225,15124,447,8473,301,239,2364,17152,239,201,445,198,1293,301,471,3991,876,281,198,239,1512,17152,239,209,252,238,3180,5577,4593,238,252,1934,8314,12825,3180,5577,4593,201,17920,201,17920,37]}} +""" diff --git a/python/ray/ml/tests/test_huggingface_predictor.py b/python/ray/ml/tests/test_huggingface_predictor.py new file mode 100644 index 000000000000..8818ff808c6c --- /dev/null +++ b/python/ray/ml/tests/test_huggingface_predictor.py @@ -0,0 +1,51 @@ +import pandas as pd +import pytest + +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, +) +from transformers.pipelines import pipeline + +from ray.ml.preprocessor import Preprocessor +from ray.ml.predictors.integrations.huggingface import HuggingFacePredictor + +prompts = pd.DataFrame( + ["Complete me", "And me", "Please complete"], columns=["sentences"] +) + +# We are only testing Casual Language Modeling here + +model_checkpoint = "sshleifer/tiny-gpt2" +tokenizer_checkpoint = "sgugger/gpt2-like-tokenizer" + + +class DummyPreprocessor(Preprocessor): + def transform_batch(self, df): + self._batch_transformed = True + return df + + +@pytest.mark.parametrize("preprocessor", [True, False]) +def test_predict(preprocessor, tmpdir): + if preprocessor: + preprocessor = DummyPreprocessor() + else: + preprocessor = None + model_config = AutoConfig.from_pretrained(model_checkpoint) + model = AutoModelForCausalLM.from_config(model_config) + predictor = HuggingFacePredictor( + pipeline=pipeline( + task="text-generation", + model=model, + tokenizer=AutoTokenizer.from_pretrained(tokenizer_checkpoint), + ), + preprocessor=preprocessor, + ) + + predictions = predictor.predict(prompts) + + assert len(predictions) == 3 + if preprocessor: + assert hasattr(predictor.preprocessor, "_batch_transformed") diff --git a/python/ray/ml/tests/test_huggingface_trainer.py b/python/ray/ml/tests/test_huggingface_trainer.py new file mode 100644 index 000000000000..9b9c382455c5 --- /dev/null +++ b/python/ray/ml/tests/test_huggingface_trainer.py @@ -0,0 +1,99 @@ +import pandas as pd +import pytest + +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, + Trainer, + TrainingArguments, +) + +import 
ray.data +from ray.ml.train.integrations.huggingface import HuggingFaceTrainer +from ray.ml.predictors.integrations.huggingface import HuggingFacePredictor +from ray.ml.batch_predictor import BatchPredictor + +from ray.ml.tests._huggingface_data import train_data, validation_data + +# First 16 rows of tokenized wikitext-2-raw-v1 training & validation +train_df = pd.read_json(train_data) +validation_df = pd.read_json(validation_data) +prompts = pd.DataFrame( + ["Complete me", "And me", "Please complete"], columns=["sentences"] +) + +# We are only testing Causal Language Modeling here + +model_checkpoint = "sshleifer/tiny-gpt2" +tokenizer_checkpoint = "sgugger/gpt2-like-tokenizer" + + +@pytest.fixture +def ray_start_4_cpus(): + address_info = ray.init(num_cpus=4) + yield address_info + # The code after the yield will run as teardown code. + ray.shutdown() + + +def train_function(train_dataset, eval_dataset=None, **config): + model_config = AutoConfig.from_pretrained(model_checkpoint) + model = AutoModelForCausalLM.from_config(model_config) + training_args = TrainingArguments( + f"{model_checkpoint}-wikitext2", + evaluation_strategy="epoch", + num_train_epochs=config.get("epochs", 3), + learning_rate=2e-5, + weight_decay=0.01, + disable_tqdm=True, + no_cuda=True, + save_strategy=config.get("save_strategy", "no"), + ) + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + ) + return trainer + + +@pytest.mark.parametrize("save_strategy", ["no", "epoch"]) +def test_e2e(ray_start_4_cpus, save_strategy): + ray_train = ray.data.from_pandas(train_df) + ray_validation = ray.data.from_pandas(validation_df) + scaling_config = {"num_workers": 2, "use_gpu": False} + trainer = HuggingFaceTrainer( + trainer_init_per_worker=train_function, + trainer_init_config={"epochs": 3, "save_strategy": save_strategy}, + scaling_config=scaling_config, + datasets={"train": ray_train, "evaluation": ray_validation}, + ) + result = trainer.fit() + + assert result.metrics["epoch"] == 3 + assert result.metrics["training_iteration"] == 3 + assert result.checkpoint + + trainer2 = HuggingFaceTrainer( + trainer_init_per_worker=train_function, + trainer_init_config={"epochs": 4}, # this will train for 1 epoch: 4 - 3 = 1 + scaling_config=scaling_config, + datasets={"train": ray_train, "evaluation": ray_validation}, + resume_from_checkpoint=result.checkpoint, + ) + result2 = trainer2.fit() + + assert result2.metrics["epoch"] == 4 + assert result2.checkpoint + + predictor = BatchPredictor.from_checkpoint( + result2.checkpoint, + HuggingFacePredictor, + task="text-generation", + tokenizer=AutoTokenizer.from_pretrained(tokenizer_checkpoint), + ) + + predictions = predictor.predict(ray.data.from_pandas(prompts)) + assert predictions.count() == 3 diff --git a/python/ray/ml/train/data_parallel_trainer.py b/python/ray/ml/train/data_parallel_trainer.py index 26ad9a747e42..7eafdad20f63 100644 --- a/python/ray/ml/train/data_parallel_trainer.py +++ b/python/ray/ml/train/data_parallel_trainer.py @@ -1,7 +1,7 @@ import inspect import logging from pathlib import Path -from typing import Dict, Callable, List, Optional, Union, TYPE_CHECKING +from typing import Callable, Dict, List, Optional, Union, Type, TYPE_CHECKING import ray from ray import tune @@ -25,6 +25,28 @@ logger = logging.getLogger(__name__) +# TODO(team-ml): Refactor checkpoint management along with Tune.
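+# Bundles the preprocessor into each checkpoint dict and writes it through Tune's checkpoint directory; subclasses can override write_checkpoint for custom persistence.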
+class _DataParallelCheckpointManager(TuneCheckpointManager): + def on_init(self, preprocessor: Preprocessor): + self.preprocessor = preprocessor + super(_DataParallelCheckpointManager, self).on_init() + + def write_checkpoint(self, checkpoint: Dict): + self.add_tune_checkpoint_id(checkpoint) + + # Add the preprocessor to the checkpoint. + checkpoint[PREPROCESSOR_KEY] = self.preprocessor + + checkpoint_obj = Checkpoint.from_dict(checkpoint) + # If inside a Tune Trainable, then checkpoint with Tune. + with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: + checkpoint_obj.to_directory(path=checkpoint_dir) + + @property + def latest_checkpoint_dir(self) -> Optional[Path]: + raise NotImplementedError + + @DeveloperAPI class DataParallelTrainer(Trainer): """A Trainer for data parallel training. @@ -186,6 +208,10 @@ def __init__(self, train_loop_per_worker, my_backend_config: resume_from_checkpoint: A checkpoint to resume training from. """ + _checkpoint_manager_cls: Type[ + TuneCheckpointManager + ] = _DataParallelCheckpointManager + _scaling_config_allowed_keys = [ "num_workers", "num_cpus_per_worker", @@ -286,7 +312,7 @@ def training_loop(self) -> None: max_retries=0, ) - checkpoint_manager = _DataParallelCheckpointManager() + checkpoint_manager = self._checkpoint_manager_cls() checkpoint_manager.on_init(preprocessor=self.preprocessor) # Start the remote actors. @@ -323,28 +349,6 @@ def training_loop(self) -> None: backend_executor.shutdown() -# TODO(team-ml): Refactor checkpoint management along with Tune. -class _DataParallelCheckpointManager(TuneCheckpointManager): - def on_init(self, preprocessor: Preprocessor): - self.preprocessor = preprocessor - super(_DataParallelCheckpointManager, self).on_init() - - def write_checkpoint(self, checkpoint: Dict): - self.add_tune_checkpoint_id(checkpoint) - - # Add the preprocessor to the checkpoint. - checkpoint[PREPROCESSOR_KEY] = self.preprocessor - - checkpoint_obj = Checkpoint.from_dict(checkpoint) - # If inside a Tune Trainable, then checkpoint with Tune. 
- with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: - checkpoint_obj.to_directory(path=checkpoint_dir) - - @property - def latest_checkpoint_dir(self) -> Optional[Path]: - raise NotImplementedError - - def _default_dataset_split_fn( dataset_dict: Dict[str, "Dataset"], training_worker_handles: List[ActorHandle] ) -> List[Dict[str, "Dataset"]]: diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py index 09dbb4e99cee..7376cc7b7ff9 100644 --- a/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py +++ b/python/ray/ml/train/integrations/huggingface/huggingface_trainer.py @@ -1,16 +1,88 @@ -from typing import Any, Callable, Optional, Dict +import inspect +import os +import shutil +import tempfile +from distutils.version import LooseVersion +from pathlib import Path +from typing import Any, Callable, Dict, Optional, Tuple, Type, Union +from ray.ml.utils.torch_utils import load_torch_model -from transformers.trainer import Trainer +import torch +import transformers +import transformers.modeling_utils +import transformers.trainer +from transformers.trainer import WEIGHTS_NAME, TRAINING_ARGS_NAME +import transformers.training_args +import ray.cloudpickle as cpickle from torch.utils.data import Dataset as TorchDataset -from ray.train.torch import TorchConfig -from ray.ml.trainer import GenDataset -from ray.ml.train.integrations.torch import TorchTrainer -from ray.ml.config import ScalingConfig, RunConfig -from ray.ml.preprocessor import Preprocessor +from ray import train +from ray import tune +from ray.util import PublicAPI, get_node_ip_address from ray.ml.checkpoint import Checkpoint -from ray.util import PublicAPI -from ray.ml.constants import TRAIN_DATASET_KEY, EVALUATION_DATASET_KEY +from ray.ml.config import RunConfig, ScalingConfig +from ray.ml.constants import EVALUATION_DATASET_KEY, PREPROCESSOR_KEY, TRAIN_DATASET_KEY +from ray.ml.preprocessor import Preprocessor +from ray.ml.train.integrations.torch import TorchTrainer +from ray.ml.trainer import GenDataset +from ray.ml.train.data_parallel_trainer import _DataParallelCheckpointManager +from ray.ml.train.integrations.huggingface.huggingface_utils import ( + CHECKPOINT_PATH_ON_NODE_KEY, + NODE_IP_KEY, + process_datasets, + TrainReportCallback, + wrap_transformers_trainer, +) +from ray.train.constants import TUNE_CHECKPOINT_ID +from ray.train.torch import TorchConfig +from ray.tune.trainable import Trainable +from ray.tune.utils.file_transfer import delete_on_node, sync_dir_between_nodes + +# This trainer uses a special checkpoint syncing logic. +# Because HF checkpoints are very large dirs (at least several GBs), +# we use directory checkpoints that are synced between nodes when +# required instead of serializing the checkpoints and sending +# bytes over nodes. This is a much more performant solution for +# large directory checkpoints. The current implementation +# is special for HuggingFaceTrainer, but can and should be +# made generic. +# TODO(ml-team): Make dir syncing checkpoint logic generic. + + +# The checkpoint is turned into a dict with node ip & path +# in HuggingFaceTrainer.as_trainable +# TODO(team-ml): Refactor checkpoint management along with Tune. 
+class _DataParallelSyncingCheckpointManager(_DataParallelCheckpointManager): + """As _DataParallelCheckpointManager, but syncs the dir instead of serializing.""" + + def write_checkpoint(self, checkpoint: Dict): + # If inside a Tune Trainable, then checkpoint with Tune. + with tune.checkpoint_dir(step=self._latest_checkpoint_id) as checkpoint_dir: + source_ip = checkpoint[NODE_IP_KEY] + source_path = checkpoint[CHECKPOINT_PATH_ON_NODE_KEY] + target_ip = get_node_ip_address() + if source_ip == target_ip: + # Move contents of source_path, but not source_path + # itself. shutil.move is already recursive. + for path in Path(source_path).iterdir(): + shutil.move(str(path.absolute()), checkpoint_dir) + shutil.rmtree(source_path, ignore_errors=True) + else: + sync_dir_between_nodes( + source_ip=source_ip, + source_path=source_path, + target_ip=target_ip, + target_path=checkpoint_dir, + return_futures=False, + max_size_bytes=None, + ) + delete_on_node(node_ip=source_ip, path=source_path) + checkpoint_dir = Path(checkpoint_dir) + with open(checkpoint_dir.joinpath(PREPROCESSOR_KEY), "wb") as f: + cpickle.dump(self.preprocessor, f) + # add tune checkpoint id + with open(checkpoint_dir.joinpath(TUNE_CHECKPOINT_ID), "w") as f: + f.write(str(self._latest_checkpoint_id)) @PublicAPI(stability="alpha") @@ -36,11 +108,15 @@ class HuggingFaceTrainer(TorchTrainer): the ``get_train_dataloader`` method will be overriden to disable distributed sampling, as the dataset will already be sharded. - Hugging Face loggers will be automatically disabled, and the ``local_rank`` - argument in ``TrainingArguments`` will be automatically set. + HuggingFace loggers will be automatically disabled, and the ``local_rank`` + argument in ``TrainingArguments`` will be automatically set. Please note + that if you want to use CPU training, you will need to set the ``no_cuda`` + argument in ``TrainingArguments`` manually - otherwise, an exception + (segfault) may be thrown. Example: .. code-block:: python + # Based on # huggingface/notebooks/examples/language_modeling_from_scratch.ipynb @@ -131,13 +207,6 @@ def trainer_init_per_worker(train_dataset, eval_dataset, **config): and config as kwargs. The Torch Datasets are automatically created by converting the Ray Datasets internally before they are passed into the function. - trainer_init_config: Configurations to pass into - ``trainer_init_per_worker`` as kwargs. - torch_config: Configuration for setting up the PyTorch backend. If set to - None, use the default configuration. This replaces the ``backend_config`` - arg of ``DataParallelTrainer``. Same as in ``TorchTrainer``. - scaling_config: Configuration for how to scale data parallel training. - run_config: Configuration for the execution of the training run. datasets: Any Ray Datasets to use for training. Use the key "train" to denote which dataset is the training dataset and (optionally) key "evaluation" to denote the evaluation @@ -146,52 +215,302 @@ def trainer_init_per_worker(train_dataset, eval_dataset, **config): If a ``preprocessor`` is provided and has not already been fit, it will be fit on the training dataset. All datasets will be transformed by the ``preprocessor`` if one is provided. + trainer_init_config: Configurations to pass into + ``trainer_init_per_worker`` as kwargs. + torch_config: Configuration for setting up the PyTorch backend. If set to + None, use the default configuration. This replaces the ``backend_config`` + arg of ``DataParallelTrainer``. Same as in ``TorchTrainer``. 
+ scaling_config: Configuration for how to scale data parallel training. + run_config: Configuration for the execution of the training run. preprocessor: A ray.ml.preprocessor.Preprocessor to preprocess the provided datasets. resume_from_checkpoint: A checkpoint to resume training from. """ + _checkpoint_manager_cls = _DataParallelSyncingCheckpointManager + def __init__( self, + *, trainer_init_per_worker: Callable[ - [TorchDataset, Optional[TorchDataset], Any], Trainer + [TorchDataset, Optional[TorchDataset], Any], transformers.trainer.Trainer ], + datasets: Dict[str, GenDataset], trainer_init_config: Optional[Dict] = None, torch_config: Optional[TorchConfig] = None, scaling_config: Optional[ScalingConfig] = None, run_config: Optional[RunConfig] = None, - datasets: Optional[Dict[str, GenDataset]] = None, preprocessor: Optional[Preprocessor] = None, resume_from_checkpoint: Optional[Checkpoint] = None, ): - self._validate_train_loop_per_worker( + # Functionality required for HuggingFaceTrainer was only added in this + # version + if LooseVersion(transformers.__version__) < LooseVersion("4.18.0"): + raise RuntimeError( + "HuggingFaceTrainer requires transformers>=4.18.0, but you " + f"have {transformers.__version__} which is incompatible. " + "Update on all nodes with `pip install -U 'transformers>=4.18.0'`." + ) + + self._validate_trainer_init_per_worker( trainer_init_per_worker, "trainer_init_per_worker" ) - assert TRAIN_DATASET_KEY in datasets - assert all( - key in (TRAIN_DATASET_KEY, EVALUATION_DATASET_KEY) for key in datasets - ) + trainer_init_config = trainer_init_config.copy() if trainer_init_config else {} + if "_trainer_init_per_worker" in trainer_init_config: + raise ValueError( + "'_trainer_init_per_worker' is a reserved key in `trainer_init_config`." + ) + trainer_init_config["_trainer_init_per_worker"] = trainer_init_per_worker super().__init__( - self._create_train_func(trainer_init_per_worker), - trainer_init_config, - torch_config, - scaling_config, - run_config, - datasets, - preprocessor, - resume_from_checkpoint, + train_loop_per_worker=_huggingface_train_loop_per_worker, + train_loop_config=trainer_init_config, + torch_config=torch_config, + scaling_config=scaling_config, + run_config=run_config, + datasets=datasets, + preprocessor=preprocessor, + resume_from_checkpoint=resume_from_checkpoint, ) - def _create_train_func(self, trainer_init_per_worker): - def train_loop_per_worker(config): - # Set to None just to make CI pass & show - # the intended usage with trainer_init_per_worker - train_dataset = None - eval_dataset = None - trainer = trainer_init_per_worker(train_dataset, eval_dataset, **config) - trainer.train() + def _validate_trainer_init_per_worker( + self, trainer_init_per_worker: Callable, fn_name: str + ) -> None: + num_params = len(inspect.signature(trainer_init_per_worker).parameters) + if num_params < 3: + raise ValueError( + f"{fn_name} should take in at least 3 arguments, " + f"but it accepts {num_params} arguments instead." + ) - return train_loop_per_worker + def _validate_attributes(self): + # exceptions first + if TRAIN_DATASET_KEY not in self.datasets: + raise KeyError( + f"'{TRAIN_DATASET_KEY}' key must be present in `datasets`. " + f"Got {list(self.datasets.keys())}" + ) + if not all( + key in (TRAIN_DATASET_KEY, EVALUATION_DATASET_KEY) for key in self.datasets + ): + raise KeyError( + f"Only '{TRAIN_DATASET_KEY}' and '{EVALUATION_DATASET_KEY}' " + "keys can be present in `datasets`. 
" f"Got {list(self.datasets.keys())}" + ) + gpus_per_worker = self.scaling_config.get("num_gpus_per_worker", 0) + if gpus_per_worker > 1: + raise ValueError( + f"You have assigned {gpus_per_worker} GPUs per worker. " + "This is not supported by HuggingFace, which expects " + "one GPU per worker in DDP mode and will fail " + "if more are assigned." + ) + if gpus_per_worker != int(gpus_per_worker): + raise ValueError( + f"You have assigned {gpus_per_worker} GPUs per worker, " + "but fractional GPUs are not supported by HuggingFace." + ) + + super()._validate_attributes() + + def _convert_directory_checkpoint_to_sync_if_needed( + self, checkpoint: Checkpoint + ) -> Checkpoint: + """Replace the directory checkpoint with a node ip & path dict checkpoint. + + This dict checkpoint will be used to sync the directory. + If we were to use a directory checkpoint directly, it would get deepcopied & + serialized unnecessarily.""" + with checkpoint.as_directory() as checkpoint_path: + # Load checkpoint from path. + checkpoint_path = Path(checkpoint_path).expanduser().absolute() + if not checkpoint_path.joinpath(TUNE_CHECKPOINT_ID).exists(): + # If the ID file is missing, we assume that this is already + # a sync checkpoint + dict_checkpoint = checkpoint.to_dict() + if ( + NODE_IP_KEY not in dict_checkpoint + or CHECKPOINT_PATH_ON_NODE_KEY not in dict_checkpoint + ): + raise ValueError( + "Wrong checkpoint format. Ensure the checkpoint is a " + "result of `HuggingFaceTrainer`." + ) + return checkpoint + with open(checkpoint_path.joinpath(TUNE_CHECKPOINT_ID), "r") as f: + tune_checkpoint_id = int(f.read()) + + return Checkpoint.from_dict( + { + NODE_IP_KEY: get_node_ip_address(), + CHECKPOINT_PATH_ON_NODE_KEY: str(checkpoint_path), + TUNE_CHECKPOINT_ID: tune_checkpoint_id, + } + ) + + def setup(self) -> None: + if self.resume_from_checkpoint: + self.resume_from_checkpoint = ( + self._convert_directory_checkpoint_to_sync_if_needed( + self.resume_from_checkpoint + ) + ) + + def as_trainable(self) -> Type[Trainable]: + original_param_dict = self._param_dict.copy() + resume_from_checkpoint: Optional[Checkpoint] = self._param_dict.get( + "resume_from_checkpoint", None + ) + if resume_from_checkpoint: + self._param_dict[ + "resume_from_checkpoint" + ] = self._convert_directory_checkpoint_to_sync_if_needed( + resume_from_checkpoint + ) + try: + ret = super().as_trainable() + finally: + self._param_dict = original_param_dict + return ret + + @staticmethod + def load_huggingface_checkpoint( + checkpoint: Checkpoint, + model: Union[ + Type[transformers.modeling_utils.PreTrainedModel], torch.nn.Module + ], + tokenizer: Optional[Type[transformers.PreTrainedTokenizer]] = None, + *, + tokenizer_kwargs: Optional[Dict[str, Any]] = None, + **pretrained_model_kwargs, + ) -> Tuple[ + Union[transformers.modeling_utils.PreTrainedModel, torch.nn.Module], + transformers.training_args.TrainingArguments, + Optional[transformers.PreTrainedTokenizer], + Optional[Preprocessor], + ]: + """Load a Checkpoint from ``HuggingFaceTrainer``. + + Return the model, ``TrainingArguments``, tokenizer and AIR preprocessor + contained within. Those can be used to initialize a ``transformers.Trainer`` + object locally. + + Args: + checkpoint: The checkpoint to load the model and + preprocessor from. It is expected to be from the result of a + ``HuggingFaceTrainer`` run. + model: Either a ``transformers.PreTrainedModel`` class + (e.g. ``AutoModelForCausalLM``), or a PyTorch model to load the + weights to. 
This should be the same model used for training. + tokenizer: A ``transformers.PreTrainedTokenizer`` class to load + the model tokenizer to. If not specified, the tokenizer will + not be loaded. Will throw an exception if specified but no + tokenizer was found in the checkpoint. + tokenizer_kwargs: Dict of kwargs to pass to ``tokenizer.from_pretrained`` + call. Ignored if ``tokenizer`` is None. + **pretrained_model_kwargs: Kwargs to pass to ``model.from_pretrained`` + call. Ignored if ``model`` is not a ``transformers.PreTrainedModel`` + class. + """ + tokenizer_kwargs = tokenizer_kwargs or {} + with checkpoint.as_directory() as checkpoint_path: + preprocessor_path = os.path.join(checkpoint_path, PREPROCESSOR_KEY) + if os.path.exists(preprocessor_path): + with open(preprocessor_path, "rb") as f: + preprocessor = cpickle.load(f) + else: + preprocessor = None + if isinstance(model, torch.nn.Module): + state_dict = torch.load( + os.path.join(checkpoint_path, WEIGHTS_NAME), map_location="cpu" + ) + model = load_torch_model(saved_model=state_dict, model_definition=model) + else: + model = model.from_pretrained( + checkpoint_path, **pretrained_model_kwargs + ) + if tokenizer: + tokenizer = tokenizer.from_pretrained( + checkpoint_path, **tokenizer_kwargs + ) + training_args_path = os.path.join(checkpoint_path, TRAINING_ARGS_NAME) + if os.path.exists(training_args_path): + with open(training_args_path, "rb") as f: + training_args = torch.load(f, map_location="cpu") + else: + training_args = None + return model, training_args, tokenizer, preprocessor + + +def _huggingface_train_loop_per_worker(config): + """Per-worker training loop for HuggingFace Transformers.""" + trainer_init_per_worker = config.pop("_trainer_init_per_worker") + + # Env vars necessary for HF to set up DDP + os.environ["RANK"] = str(train.world_rank()) + os.environ["WORLD_SIZE"] = str(train.world_size()) + os.environ["LOCAL_RANK"] = str(train.local_rank()) + + train_dataset = train.get_dataset_shard(TRAIN_DATASET_KEY) + eval_dataset = train.get_dataset_shard(EVALUATION_DATASET_KEY) + + train_torch_dataset, eval_torch_dataset = process_datasets( + train_dataset, + eval_dataset, + ) + + trainer: transformers.trainer.Trainer = trainer_init_per_worker( + train_torch_dataset, eval_torch_dataset, **config + ) + + if trainer.args.push_to_hub: + raise ValueError( + "`push_to_hub` parameter in `TrainingArguments` is not supported by " + "`HuggingFaceTrainer`. If you would like to push your model to the hub " + "after training, use the `HuggingFaceTrainer.load_huggingface_checkpoint`" + " method to obtain the model from a returned checkpoint, and use it to " + "instantiate the `transformers.Trainer` class." + ) + + trainer = wrap_transformers_trainer(trainer) + + # ensure no HF logging callbacks are added + # aside from duplicating the functionality of our callbacks, + # the Wandb callback causes training to freeze + integration_callbacks = transformers.trainer.get_reporting_integration_callbacks( + trainer.args.report_to + ) + for callback in integration_callbacks: + trainer.pop_callback(callback) + + trainer.add_callback(TrainReportCallback) + + checkpoint = train.load_checkpoint() + checkpoint_path = None + remove_checkpoint_path = False + if checkpoint: + source_ip = checkpoint[NODE_IP_KEY] + source_path = checkpoint[CHECKPOINT_PATH_ON_NODE_KEY] + target_ip = get_node_ip_address() + if source_ip == target_ip: + checkpoint_path = source_path + else: + # TODO(yard1): Confirm if tempdir is the right approach here. 
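+ # The checkpoint to resume from lives on another node - stage it in a local temporary directory and sync the files over before calling trainer.train().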
+ checkpoint_path = tempfile.mkdtemp( + suffix=Path(trainer.args.output_dir).name + ) + remove_checkpoint_path = True + sync_dir_between_nodes( + source_ip=source_ip, + source_path=source_path, + target_ip=target_ip, + target_path=checkpoint_path, + return_futures=False, + max_size_bytes=None, + ) + trainer.train(resume_from_checkpoint=checkpoint_path) + if remove_checkpoint_path: + shutil.rmtree(checkpoint_path, ignore_errors=True) diff --git a/python/ray/ml/train/integrations/huggingface/huggingface_utils.py b/python/ray/ml/train/integrations/huggingface/huggingface_utils.py new file mode 100644 index 000000000000..5cc48578f0e5 --- /dev/null +++ b/python/ray/ml/train/integrations/huggingface/huggingface_utils.py @@ -0,0 +1,203 @@ +from pathlib import Path +from typing import Any, Optional, Tuple, Type + +import datasets.iterable_dataset +import transformers.trainer +from torch.utils.data import IterableDataset, DataLoader +from transformers.trainer_callback import TrainerCallback + +from ray import train +from ray.util import get_node_ip_address +from ray.data.dataset import Dataset + +# Constants for the sync checkpoint dict. See huggingface_trainer.py +CHECKPOINT_PATH_ON_NODE_KEY = "checkpoint_path_on_node" +NODE_IP_KEY = "node_ip" + + +def maybe_add_length(obj: Any, length: Optional[int]) -> Any: + """Change the class of obj to a subclass with predefined __len__ if needed.""" + # By adding length to the dataset we let HF calculate steps per epoch + # and other such values. Without length, it's not possible to use + # epochs as the evaluation strategy, which makes for poor UX. + + if not length or hasattr(obj, "__len__"): + return obj + + def __len__(self): + return length + + new_class = type( + f"{obj.__class__.__name__}WithLength", (obj.__class__,), {"__len__": __len__} + ) + obj.__class__ = new_class + return obj + + +def wrap_transformers_trainer( + trainer: transformers.trainer.Trainer, +) -> transformers.trainer.Trainer: + """Change the class of trainer to a subclass implementing Ray-specific logic.""" + base_trainer_class: Type[transformers.trainer.Trainer] = trainer.__class__ + + class RayTrainer(base_trainer_class): + # TODO(yard1): Upstream data collator removing unused columns to + # transformers. + # This is necessary to provide the same experience as with a + # non-iterable HuggingFace Dataset, which can remove columns + # not supported by the model. + def _prepare_data_collator(self): + """Wrap the data collator in a function removing superfluous columns.""" + # Hack to set the self._signature_columns attribute. + try: + self._remove_unused_columns(None, description="nan") + except AttributeError: + pass + + if self._signature_columns and not hasattr(self, "_original_data_collator"): + self._original_data_collator = self.data_collator + + def remove_columns_collator(features): + features = [ + { + k: v + for k, v in feature.items() + if k in self._signature_columns + } + for feature in features + ] + return self._original_data_collator(features) + + collator = remove_columns_collator + else: + collator = self.data_collator + + self.data_collator = collator + + def get_train_dataloader(self): + if self.train_dataset is None: + raise ValueError("Trainer: training requires a train_dataset.") + + train_dataset = self.train_dataset + + # While we are not sharding the train dataset again, this + # class ensures that the last batch has a consistent size. 
+ train_dataset = transformers.trainer.IterableDatasetShard( + train_dataset, + batch_size=self.args.train_batch_size, + drop_last=self.args.dataloader_drop_last, + ) + + return DataLoader( + train_dataset, + batch_size=self.args.per_device_train_batch_size, + collate_fn=self.data_collator, + num_workers=self.args.dataloader_num_workers, + pin_memory=self.args.dataloader_pin_memory, + ) + + trainer.__class__ = RayTrainer + trainer._prepare_data_collator() + return trainer + + +# TODO(ml-team): Replace with a Ray Datasets-HuggingFace integration when available. +class RayDatasetHFIterable(datasets.iterable_dataset.ExamplesIterable): + """HF ExamplesIterable backed by a Ray Dataset.""" + + def __init__(self, dataset: Dataset) -> None: + self.dataset = dataset + self.generate_examples_fn = self.dataset.iter_rows + + # Required for the superclass + self.kwargs = {} + + def __iter__(self): + for row in self.generate_examples_fn(**self.kwargs): + yield (0, {k: v for k, v in row.as_pydict().items()}) + + +def process_dataset_for_hf(dataset: Dataset) -> IterableDataset: + """Converts a Ray Dataset into a HF IterableDataset.""" + hf_iterable = RayDatasetHFIterable(dataset) + + iterable_dataset = datasets.iterable_dataset.IterableDataset( + hf_iterable, format_type="torch" + ).with_format("torch") + + try: + dataset_length = dataset.count() + except ValueError: + # pipeline case + dataset_length = None + + iterable_dataset = maybe_add_length(iterable_dataset, dataset_length) + return iterable_dataset + + +def process_datasets( + train_dataset: Dataset, + eval_dataset: Dataset, +) -> Tuple[IterableDataset, IterableDataset]: + """Convert Ray train and validation to HF IterableDatasets.""" + train_torch_dataset = process_dataset_for_hf(train_dataset) + + if eval_dataset: + eval_torch_dataset = process_dataset_for_hf(eval_dataset) + else: + eval_torch_dataset = None + + return train_torch_dataset, eval_torch_dataset + + +class TrainReportCallback(TrainerCallback): + """HF TrainerCallback for Ray Train metric reporting & checkpointing.""" + + def __init__(self) -> None: + # HF first logs metrics, and then checkpoints. With Ray AIR, we need the + # opposite. Furthermore, some metrics are logged just at the end. + # Therefore, if we detect that a checkpoint will be created, + # we delay the train.report call after the checkpoint is reported + # to Ray Train. + self.delayed_report = {} + self.first_report_keys = None + super().__init__() + + def on_step_end(self, args, state, control, **kwargs): + if control.should_training_stop: + # Always save at the end. + control.should_save = True + return control + + def on_log(self, args, state, control, model=None, logs=None, **kwargs): + # Log is called in multiple places (evaluation, train metrics). + report = {**logs, "step": state.global_step, "epoch": state.epoch} + if not self.first_report_keys: + self.first_report_keys = set(report) + # if saving or training end is coming, delay reporting + if not control.should_save and not control.should_training_stop: + train.report(**report) + else: + self.delayed_report.update(report) + + def on_save(self, args, state, control, **kwargs): + # Save is called after evaluation. 
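+ # Report the checkpoint HF just wrote as a node-local path - the checkpoint manager on the driver side will move or sync the directory as needed.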
+ checkpoint_path = Path( + transformers.trainer.get_last_checkpoint(args.output_dir) + ).absolute() + if checkpoint_path: + train.save_checkpoint( + **{ + NODE_IP_KEY: get_node_ip_address(), + CHECKPOINT_PATH_ON_NODE_KEY: str(checkpoint_path), + } + ) + if self.delayed_report and not control.should_training_stop: + train.report(**self.delayed_report) + self.delayed_report = {} + + def on_train_end(self, args, state, control, **kwargs): + # Final callback. Train metrics are logged right before this. + if self.delayed_report: + train.report(**self.delayed_report) + self.delayed_report = {} diff --git a/python/ray/ml/utils/torch_utils.py b/python/ray/ml/utils/torch_utils.py index c338ca06cd68..963fca711109 100644 --- a/python/ray/ml/utils/torch_utils.py +++ b/python/ray/ml/utils/torch_utils.py @@ -8,6 +8,7 @@ def convert_pandas_to_torch_tensor( data_batch: pd.DataFrame, columns: Optional[Union[List[str], List[List[str]]]] = None, column_dtypes: Optional[Union[torch.dtype, List[torch.dtype]]] = None, + unsqueeze: bool = True, ) -> Union[torch.Tensor, List[torch.Tensor]]: """Converts a Pandas dataframe to a torch Tensor or list of torch Tensors. @@ -27,6 +28,10 @@ def convert_pandas_to_torch_tensor( column_dtype (Optional[Union[torch.dtype, List[torch.dtype]): The torch dtype to use for the tensor. If set to None, then automatically infer the dtype. + unsqueeze: If set to True, the tensors + will be unsqueezed (reshaped to (N, 1)) before being concatenated into + the final tensor. Otherwise, they will be left as is, that is + (N, ). Defaults to True. Returns: Either a torch tensor of size (N, len(columns)) where N is the @@ -46,6 +51,18 @@ def convert_pandas_to_torch_tensor( columns = columns if columns else [] + def tensorize(vals, dtype): + """This recursive function allows to convert pyarrow List dtypes + to multi-dimensional tensors.""" + try: + return torch.as_tensor(vals, dtype=dtype) + except TypeError: + # This exception will be raised if vals is of object dtype + # or otherwise cannot be made into a tensor directly. + # We assume it's a sequence in that case. + # This is more robust than checking for dtype. + return torch.stack([tensorize(x, dtype) for x in vals]) + def get_tensor_for_columns(columns, dtype): feature_tensors = [] @@ -56,11 +73,14 @@ def get_tensor_for_columns(columns, dtype): for col in batch.columns: col_vals = batch[col].values - t = torch.as_tensor(col_vals, dtype=dtype) - t = t.view(-1, 1) + t = tensorize(col_vals, dtype=dtype) + if unsqueeze: + t = t.unsqueeze(1) feature_tensors.append(t) - return torch.cat(feature_tensors, dim=1) + if len(feature_tensors) > 1: + return torch.cat(feature_tensors, dim=1) + return feature_tensors[0] if multi_input: if type(column_dtypes) not in [list, tuple]: @@ -77,7 +97,7 @@ def load_torch_model( saved_model: Union[torch.nn.Module, Dict], model_definition: Optional[torch.nn.Module] = None, ) -> torch.nn.Module: - """Loads a PyTorch model from the provided``saved_model``. + """Loads a PyTorch model from the provided ``saved_model``. If ``saved_model`` is a torch Module, then return it directly. 
If ``saved_model`` is a torch state dict, then load it in the ``model_definition`` and return the loaded diff --git a/python/requirements/ml/requirements_train.txt b/python/requirements/ml/requirements_train.txt index 792f79f66147..f1a37ad85c36 100644 --- a/python/requirements/ml/requirements_train.txt +++ b/python/requirements/ml/requirements_train.txt @@ -5,9 +5,11 @@ mlflow==1.21.0 tensorboardX==2.4.1 -# Dependencies for Hugging Face examples: +# Dependencies for Hugging Face examples & tests: # `python/ray/train/examples/transformers/transformers_example.py` -transformers==4.10.0 +# `python/ray/ml/examples/huggingface/huggingface_basic_language_modeling_example.py` +# `python/ray/ml/tests/test_huggingface_trainer.py` +transformers==4.18.0 accelerate==0.5.1 -datasets==1.14.0 +datasets==2.0.0 sentencepiece==0.1.96 diff --git a/python/requirements/ml/requirements_tune.txt b/python/requirements/ml/requirements_tune.txt index 2f4f977a17f6..8c8aaabbdc7d 100644 --- a/python/requirements/ml/requirements_tune.txt +++ b/python/requirements/ml/requirements_tune.txt @@ -34,7 +34,7 @@ scikit-learn==0.24.2 scikit-optimize==0.8.1 sigopt==7.5.0 timm==0.4.5 -transformers==4.10.0 +transformers==4.18.0 wandb==0.12.5 xgboost==1.3.3 zoopt==0.4.1
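For orientation, a rough usage sketch of the API added by this diff (not part of the patch itself), closely following test_huggingface_trainer.py above. It assumes `train_df` and `validation_df` are pandas DataFrames of pre-tokenized examples, such as the test fixtures, and reuses the same tiny model checkpoint as the tests.

    import ray.data
    from transformers import AutoConfig, AutoModelForCausalLM, Trainer, TrainingArguments
    from ray.ml.train.integrations.huggingface import HuggingFaceTrainer

    model_checkpoint = "sshleifer/tiny-gpt2"

    def trainer_init_per_worker(train_dataset, eval_dataset=None, **config):
        # Build a fresh model and transformers.Trainer on each Ray Train worker.
        model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(model_checkpoint))
        args = TrainingArguments(
            f"{model_checkpoint}-wikitext2",
            evaluation_strategy="epoch",
            num_train_epochs=config.get("epochs", 3),
            no_cuda=True,  # required for CPU-only training, per the docstring note
        )
        return Trainer(
            model=model, args=args, train_dataset=train_dataset, eval_dataset=eval_dataset
        )

    trainer = HuggingFaceTrainer(
        trainer_init_per_worker=trainer_init_per_worker,
        datasets={
            "train": ray.data.from_pandas(train_df),  # assumed pre-tokenized
            "evaluation": ray.data.from_pandas(validation_df),
        },
        scaling_config={"num_workers": 2, "use_gpu": False},
    )
    result = trainer.fit()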