From 8e4d96a4dc7b82b0a9331415bf4e0ee2372a37a8 Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Fri, 17 Jul 2020 14:55:28 -0700 Subject: [PATCH] Add basic ORTTrainer API (#4435) This PR presents the public API for ORTTrainer for the short term development. It also validates and saves input parameters, which will be used in the next stages, such as building ONNX model, post processing the model and configuring the training session --- .../orttraining/python/training/__init__.py | 3 +- .../orttraining/python/training/orttrainer.py | 148 +++++++++++++++++- 2 files changed, 147 insertions(+), 4 deletions(-) diff --git a/orttraining/orttraining/python/training/__init__.py b/orttraining/orttraining/python/training/__init__.py index c5655c39a2bcb..164bed091dc17 100644 --- a/orttraining/orttraining/python/training/__init__.py +++ b/orttraining/orttraining/python/training/__init__.py @@ -5,7 +5,6 @@ from onnxruntime.capi._pybind_state import TrainingParameters from onnxruntime.capi.training.training_session import TrainingSession - from .orttrainer_options import ORTTrainerOptions -from .orttrainer import TrainStepInfo +from .orttrainer import ORTTrainer, TrainStepInfo from . import amp, optim, model_desc_validation diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py index 9bbd02b47ff69..1b228bf36ebbe 100644 --- a/orttraining/orttraining/python/training/orttrainer.py +++ b/orttraining/orttraining/python/training/orttrainer.py @@ -1,4 +1,10 @@ -from .optim import _OptimizerConfig +import onnx +import torch + +from . import ORTTrainerOptions +from . import optim +from .model_desc_validation import _ORTTrainerModelDesc + class TrainStepInfo(object): r"""Private class used to store runtime information from current train step. @@ -17,6 +23,7 @@ class TrainStepInfo(object): optimizer_config (optim._OptimizerConfig): reference to optimizer config Example: + .. code-block:: python info = TrainStepInfo(all_finite=True, step=0, optimizer_config=optim.SGDConfig(lr=0.01)) @@ -30,9 +37,146 @@ def __init__(self, all_finite=None, step=None, optimizer_config=None): "all_finite must be either None or a bool" assert step is None or (isinstance(step, int) and step >= 0),\ "step must be either None or a positive int" - assert optimizer_config is None or isinstance(optimizer_config, _OptimizerConfig),\ + assert optimizer_config is None or isinstance(optimizer_config, optim._OptimizerConfig),\ "optimizer_config must be either None or optim._OptimizerConfig" self.all_finite = all_finite self.step = step self.optimizer_config = optimizer_config + + +class ORTTrainer(object): + r"""Pytorch frontend for ONNX Runtime training + + Entry point that exposes the C++ backend of ORT as a Pytorch frontend. + + Args: + model (torch.nn.Module or onnx.ModelProto): either a PyTorch or ONNX model. + When a PyTorch model and :py:attr:`loss_fn` are specified, :py:attr:`model` and :py:obj:`loss_fn` are combined. + When a ONNX model is provided, the loss is identified by the flag :py:obj:`is_loss=True` in one of the :py:attr:`.model_desc.outputs` entries. + model_desc (dict): model input and output description. + This is used to identify inputs and outputs and their shapes, so that ORT can generate back propagation graph, plan memory allocation for + training, and perform optimizations. + :py:attr:`model_desc` must be consistent with the training :py:attr:`model` and have the following (:py:obj:`dict`) schema + :py:obj:`{ 'inputs': [tuple(name, shape)], 'outputs': [tuple(name, shape, is_loss)]}`. + :py:attr:`name` is a string representing the name of input or output of the model. + For :py:obj:`model_desc['inputs']` entries, :py:attr:`name` must match input names of the original PyTorch model's :py:meth:`torch.nn.Module.forward` method. + For ONNX models, both name and order of input names must match. + For :py:obj:`model_desc['outputs']` entries, the order must match the original PyTorch's output as returned by :py:meth:`torch.nn.Module.forward` method. + For ONNX models, both name and order of output names must match. + :py:attr:`shape` is a list of string or integers that describes the shape of the input/output. + Each dimension size can be either a string or an int. String means the dimension size is dynamic, while integers mean static dimensions. + An empty list implies a scalar. + Lastly, :py:attr:`is_loss` is a boolean (default is False) that flags if this output is considered a loss. + ORT backend needs to know which output is loss in order to generate back propagation graph. + Loss output must be specified when either :py:attr:`loss_fn` is specified or when loss is embedded in the model. + Note that only one loss output is supported per model. + optimizer_config (optim._OptimizerConfig): optimizer config. + One of :py:class:`.optim.AdamConfig`, :py:class:`.optim.LambConfig` or :py:class:`.optim.SGDConfig`. + loss_fn (default is None): a PyTorch loss function. + It takes two inputs [prediction, label] and output a loss tensor. + If provided, :py:attr:`loss_fn` is combined with the PyTorch :py:attr:`model` to form a combined PyTorch model. + Inputs to the combined PyTorch model are concatenation of the :py:attr:`model`'s input and :py:attr:`loss_fn`'s label input. + Outputs of the combined PyTorch model are concatenation of :py:attr:`loss_fn`'s loss output and :py:attr:`model`'s outputs. + options (ORTTrainerOptions, default is None): options for additional features. + + Example: + + .. code-block:: python + + model = ... + model_desc = { + "inputs": [ + ("input_ids", ["batch", "max_seq_len_in_batch"]), + ("attention_mask", ["batch", "max_seq_len_in_batch"]), + ("token_type_ids", ["batch", "max_seq_len_in_batch"]), + ("masked_lm_labels", ["batch", "max_seq_len_in_batch"]), + ("next_sentence_label", ["batch", 1]) + ], + "outputs": [ + ("loss", [], True), + ], + } + optim_config = optim.LambConfig(param_groups = [ { 'params' : ['model_param0'], 'alpha' : 0.8, 'beta' : 0.7}, + { 'params' : ['model_param1' , 'model_param_2'], 'alpha' : 0.0} + ], + alpha=0.9, beta=0.999) + ort_trainer = ORTTrainer(model, model_desc, optim_config) + """ + + def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None): + # Basic validation + assert model is not None, "'model' is required and must be either a 'torch.nn.Module' or ONNX model" + assert isinstance(model_desc, dict), "'model_desc' must be a 'dict'" + assert isinstance(optim_config, optim._OptimizerConfig),\ + "'optim_config' is required and must be any of 'AdamConfig', 'LambConfig' or 'SGDConfig'" + assert loss_fn is None or isinstance(loss_fn, torch.nn.Module),\ + "'loss_fn' must be either 'None' or 'torch.nn.Module'" + assert options is None or isinstance(options, ORTTrainerOptions),\ + "'loss_fn' must be either 'None' or 'ORTTrainerOptions'" + + # Model + Loss validation + # Supported combinarios are + # ---------------------------------------- + # | | Model | Loss | + # ---------------------------------------- + # | 1 | torch.nn.Module | None | + # | 2 | torch.nn.Module | torch.nn.Module | + # | 3 | ONNX | None | + # ---------------------------------------- + self._torch_model = None + self._onnx_model = None + if isinstance(model, torch.nn.Module): + assert loss_fn is None or isinstance(model, torch.nn.Module),\ + "'loss_fn' must be either 'None' or 'torch.nn.Module'" + self._torch_model = model + self._loss_fn = loss_fn + elif isinstance(model, onnx.ModelProto): + assert loss_fn is None, "'loss_fn' must not be specified when 'model' is an ONNX model" + self._onnx_model = model + self._loss_fn = None + else: + raise ValueError("'model' must be either 'torch.nn.Module' or 'onnx.ModelProto'") + + self.model_desc = _ORTTrainerModelDesc(model_desc) + self.optim_config = optim_config + self.options = ORTTrainerOptions(options) + + def eval_step(self, *input, **kwargs): + r"""Evaluation step method + + Args: + *input: Arbitrary arguments that are used as model input (data only) + **kwargs: Arbitrary keyword arguments that are used as model input (data only) + + Returns: + ordered :py:obj:`list` with model outputs as described by :py:attr:`.ORTTrainer.model_desc` + """ + pass + + def save_as_onnx(self, path): + r"""Persists ONNX model into :py:attr:`path` + + The model will be saved as a Google Protocol Buffers (aka protobuf) file as per ONNX standard containing + the full graph, including inference and training metadata. + + Args: + path (str): Full path, including filename, to save the model in the filesystem + """ + pass + + def train_step(self, *input, **kwargs): + r"""Train step method + + After forward pass, an ordered list with all outputs described at :py:attr:`ORTTrainer.model_desc` is returned. + Additional information relevant to the train step is maintend by :py:attr:`ORTTrainer._train_step_info`. + See :py:class:`.TrainStepInfo` for details. + + Args: + *input: Arbitrary arguments that are used as model input (data only) + **kwargs: Arbitrary keyword arguments that are used as model input (data only) + + Returns: + ordered :py:obj:`list` with model outputs as described by :py:attr:`ORTTrainer.model_desc` + """ + pass