diff --git a/orttraining/orttraining/python/training/__init__.py b/orttraining/orttraining/python/training/__init__.py
index c5655c39a2bcb..164bed091dc17 100644
--- a/orttraining/orttraining/python/training/__init__.py
+++ b/orttraining/orttraining/python/training/__init__.py
@@ -5,7 +5,6 @@
 
 from onnxruntime.capi._pybind_state import TrainingParameters
 from onnxruntime.capi.training.training_session import TrainingSession
-
 from .orttrainer_options import ORTTrainerOptions
-from .orttrainer import TrainStepInfo
+from .orttrainer import ORTTrainer, TrainStepInfo
 from . import amp, optim, model_desc_validation
diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py
index 9bbd02b47ff69..1b228bf36ebbe 100644
--- a/orttraining/orttraining/python/training/orttrainer.py
+++ b/orttraining/orttraining/python/training/orttrainer.py
@@ -1,4 +1,10 @@
-from .optim import _OptimizerConfig
+import onnx
+import torch
+
+from . import ORTTrainerOptions
+from . import optim
+from .model_desc_validation import _ORTTrainerModelDesc
+
 
 class TrainStepInfo(object):
     r"""Private class used to store runtime information from current train step.
@@ -17,6 +23,7 @@ class TrainStepInfo(object):
             optimizer_config (optim._OptimizerConfig): reference to optimizer config
 
     Example:
+        .. code-block:: python
 
             info = TrainStepInfo(all_finite=True, step=0, optimizer_config=optim.SGDConfig(lr=0.01))
 
@@ -30,9 +37,146 @@ def __init__(self, all_finite=None, step=None, optimizer_config=None):
             "all_finite must be either None or a bool"
         assert step is None or (isinstance(step, int) and step >= 0),\
             "step must be either None or a positive int"
-        assert optimizer_config is None or isinstance(optimizer_config, _OptimizerConfig),\
+        assert optimizer_config is None or isinstance(optimizer_config, optim._OptimizerConfig),\
             "optimizer_config must be either None or optim._OptimizerConfig"
 
         self.all_finite = all_finite
         self.step = step
         self.optimizer_config = optimizer_config
+
+
+class ORTTrainer(object):
+    r"""PyTorch frontend for ONNX Runtime training
+
+    Entry point that exposes the C++ backend of ORT as a PyTorch frontend.
+
+    Args:
+        model (torch.nn.Module or onnx.ModelProto): either a PyTorch or ONNX model.
+            When a PyTorch model and :py:attr:`loss_fn` are specified, :py:attr:`model` and :py:obj:`loss_fn` are combined.
+            When an ONNX model is provided, the loss is identified by the flag :py:obj:`is_loss=True` in one of the :py:attr:`.model_desc.outputs` entries.
+        model_desc (dict): model input and output description.
+            This is used to identify inputs and outputs and their shapes, so that ORT can generate the back-propagation graph, plan memory allocation for
+            training, and perform optimizations.
+            :py:attr:`model_desc` must be consistent with the training :py:attr:`model` and have the following (:py:obj:`dict`) schema
+            :py:obj:`{ 'inputs': [tuple(name, shape)], 'outputs': [tuple(name, shape, is_loss)]}`.
+            :py:attr:`name` is a string representing the name of an input or output of the model.
+            For :py:obj:`model_desc['inputs']` entries, :py:attr:`name` must match input names of the original PyTorch model's :py:meth:`torch.nn.Module.forward` method.
+            For ONNX models, both the name and the order of input names must match.
+            For :py:obj:`model_desc['outputs']` entries, the order must match the original PyTorch model's outputs as returned by the :py:meth:`torch.nn.Module.forward` method.
+            For ONNX models, both the name and the order of output names must match.
+            :py:attr:`shape` is a list of strings or integers that describes the shape of the input/output.
+            Each dimension size can be either a string or an int. A string means the dimension size is dynamic, while an int means it is static.
+            An empty list implies a scalar.
+            Lastly, :py:attr:`is_loss` is a boolean (default is False) that flags whether this output is considered a loss.
+            The ORT backend needs to know which output is the loss in order to generate the back-propagation graph.
+            The loss output must be specified either when :py:attr:`loss_fn` is specified or when the loss is embedded in the model.
+            Note that only one loss output is supported per model.
+        optimizer_config (optim._OptimizerConfig): optimizer config.
+            One of :py:class:`.optim.AdamConfig`, :py:class:`.optim.LambConfig` or :py:class:`.optim.SGDConfig`.
+        loss_fn (torch.nn.Module, default is None): a PyTorch loss function.
+            It takes two inputs [prediction, label] and outputs a loss tensor.
+            If provided, :py:attr:`loss_fn` is combined with the PyTorch :py:attr:`model` to form a combined PyTorch model.
+            Inputs to the combined PyTorch model are the concatenation of the :py:attr:`model`'s inputs and :py:attr:`loss_fn`'s label input.
+            Outputs of the combined PyTorch model are the concatenation of :py:attr:`loss_fn`'s loss output and :py:attr:`model`'s outputs.
+        options (ORTTrainerOptions, default is None): options for additional features.
+
+    Example:
+
+        .. code-block:: python
+
+            model = ...
+            model_desc = {
+                "inputs": [
+                    ("input_ids", ["batch", "max_seq_len_in_batch"]),
+                    ("attention_mask", ["batch", "max_seq_len_in_batch"]),
+                    ("token_type_ids", ["batch", "max_seq_len_in_batch"]),
+                    ("masked_lm_labels", ["batch", "max_seq_len_in_batch"]),
+                    ("next_sentence_label", ["batch", 1])
+                ],
+                "outputs": [
+                    ("loss", [], True),
+                ],
+            }
+            optim_config = optim.LambConfig(param_groups=[{'params': ['model_param0'], 'alpha': 0.8, 'beta': 0.7},
+                                                          {'params': ['model_param1', 'model_param_2'], 'alpha': 0.0}],
+                                            alpha=0.9, beta=0.999)
+            ort_trainer = ORTTrainer(model, model_desc, optim_config)
+    """
+
+    def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None):
+        # Basic validation
+        assert model is not None, "'model' is required and must be either a 'torch.nn.Module' or ONNX model"
+        assert isinstance(model_desc, dict), "'model_desc' must be a 'dict'"
+        assert isinstance(optim_config, optim._OptimizerConfig),\
+            "'optim_config' is required and must be any of 'AdamConfig', 'LambConfig' or 'SGDConfig'"
+        assert loss_fn is None or isinstance(loss_fn, torch.nn.Module),\
+            "'loss_fn' must be either 'None' or 'torch.nn.Module'"
+        assert options is None or isinstance(options, ORTTrainerOptions),\
+            "'options' must be either 'None' or 'ORTTrainerOptions'"
+
+        # Model + Loss validation
+        # Supported combinations are
+        #  ----------------------------------------
+        # |   | Model            | Loss            |
+        #  ----------------------------------------
+        # | 1 | torch.nn.Module  | None            |
+        # | 2 | torch.nn.Module  | torch.nn.Module |
+        # | 3 | ONNX             | None            |
+        #  ----------------------------------------
+        self._torch_model = None
+        self._onnx_model = None
+        if isinstance(model, torch.nn.Module):
+            assert loss_fn is None or isinstance(loss_fn, torch.nn.Module),\
+                "'loss_fn' must be either 'None' or 'torch.nn.Module'"
+            self._torch_model = model
+            self._loss_fn = loss_fn
+        elif isinstance(model, onnx.ModelProto):
+            assert loss_fn is None, "'loss_fn' must not be specified when 'model' is an ONNX model"
+            self._onnx_model = model
+            self._loss_fn = None
+        else:
+            raise ValueError("'model' must be either 'torch.nn.Module' or 'onnx.ModelProto'")
+
+        self.model_desc = _ORTTrainerModelDesc(model_desc)
+        self.optim_config = optim_config
+        self.options = ORTTrainerOptions(options)
+
+    def eval_step(self, *input, **kwargs):
+        r"""Evaluation step method
+
+        Args:
+            *input: Arbitrary arguments that are used as model input (data only)
+            **kwargs: Arbitrary keyword arguments that are used as model input (data only)
+
+        Returns:
+            ordered :py:obj:`list` with model outputs as described by :py:attr:`.ORTTrainer.model_desc`
+        """
+        pass
+
+    def save_as_onnx(self, path):
+        r"""Persists the ONNX model into :py:attr:`path`
+
+        The model will be saved as a Google Protocol Buffers (aka protobuf) file as per the ONNX standard, containing
+        the full graph, including inference and training metadata.
+
+        Args:
+            path (str): Full path, including filename, to save the model in the filesystem
+        """
+        pass
+
+    def train_step(self, *input, **kwargs):
+        r"""Train step method
+
+        After the forward pass, an ordered list with all outputs described at :py:attr:`ORTTrainer.model_desc` is returned.
+        Additional information relevant to the train step is maintained by :py:attr:`ORTTrainer._train_step_info`.
+        See :py:class:`.TrainStepInfo` for details.
+
+        Args:
+            *input: Arbitrary arguments that are used as model input (data only)
+            **kwargs: Arbitrary keyword arguments that are used as model input (data only)
+
+        Returns:
+            ordered :py:obj:`list` with model outputs as described by :py:attr:`ORTTrainer.model_desc`
+        """
+        pass
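
Usage sketch for the constructor API added above, assuming the package is importable as onnxruntime.training. The toy linear model, the MSELoss instance, and the model_desc names and shapes ("x", "target", "loss") are illustrative placeholders only; train_step, eval_step and save_as_onnx are still stubs at this stage, so only construction is exercised.

.. code-block:: python

    import torch
    from onnxruntime.training import ORTTrainer, optim

    # Placeholder model and loss; loss_fn takes (prediction, label) and returns a scalar loss
    model = torch.nn.Linear(10, 1)
    loss_fn = torch.nn.MSELoss()

    # Names and shapes are illustrative: "batch" marks a dynamic dimension, ints are static,
    # and the empty shape with is_loss=True flags the scalar loss output
    model_desc = {
        "inputs": [
            ("x", ["batch", 10]),
            ("target", ["batch", 1]),
        ],
        "outputs": [
            ("loss", [], True),
        ],
    }

    optim_config = optim.SGDConfig(lr=0.01)
    trainer = ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn)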