Add basic ORTTrainer API (#4435)
This PR presents the public API for ORTTrainer for short-term development.

It also validates and saves the input parameters, which will be used in the
next stages, such as building the ONNX model, post-processing it, and
configuring the training session.
Thiago Crepaldi committed Aug 14, 2020
1 parent ff91cab commit 8e4d96a
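
A minimal usage sketch (not part of this commit), based on the docstrings added in the diff below. The onnxruntime.training import path, the toy Net model, and the MSELoss loss module are assumptions; train_step/eval_step are still stubs at this stage, so the loop only illustrates the intended call pattern:

import torch

# Assumed import path for the package built from orttraining/python/training.
from onnxruntime.training import ORTTrainer, optim

# Hypothetical model; forward() argument names must match model_desc['inputs'].
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

model_desc = {
    "inputs": [("x", ["batch", 10]),       # 'batch' is a dynamic dimension
               ("target", ["batch", 1])],  # label consumed by loss_fn
    "outputs": [("loss", [], True)],       # scalar output flagged as the loss
}
optim_config = optim.SGDConfig(lr=0.01)

trainer = ORTTrainer(Net(), model_desc, optim_config, loss_fn=torch.nn.MSELoss())

# train_step/eval_step return None in this commit; once implemented they are
# expected to return the outputs listed in model_desc['outputs'].
for x, target in [(torch.randn(4, 10), torch.randn(4, 1))]:  # stand-in for a DataLoader
    loss = trainer.train_step(x, target)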
Showing 2 changed files with 147 additions and 4 deletions.
3 changes: 1 addition & 2 deletions orttraining/orttraining/python/training/__init__.py
@@ -5,7 +5,6 @@
from onnxruntime.capi._pybind_state import TrainingParameters
from onnxruntime.capi.training.training_session import TrainingSession


from .orttrainer_options import ORTTrainerOptions
-from .orttrainer import TrainStepInfo
+from .orttrainer import ORTTrainer, TrainStepInfo
from . import amp, optim, model_desc_validation
148 changes: 146 additions & 2 deletions orttraining/orttraining/python/training/orttrainer.py
@@ -1,4 +1,10 @@
-from .optim import _OptimizerConfig
+import onnx
+import torch

+from . import ORTTrainerOptions
+from . import optim
+from .model_desc_validation import _ORTTrainerModelDesc


class TrainStepInfo(object):
r"""Private class used to store runtime information from current train step.
@@ -17,6 +23,7 @@ class TrainStepInfo(object):
optimizer_config (optim._OptimizerConfig): reference to optimizer config
Example:
.. code-block:: python
info = TrainStepInfo(all_finite=True, step=0, optimizer_config=optim.SGDConfig(lr=0.01))
@@ -30,9 +37,146 @@ def __init__(self, all_finite=None, step=None, optimizer_config=None):
"all_finite must be either None or a bool"
assert step is None or (isinstance(step, int) and step >= 0),\
"step must be either None or a positive int"
-assert optimizer_config is None or isinstance(optimizer_config, _OptimizerConfig),\
+assert optimizer_config is None or isinstance(optimizer_config, optim._OptimizerConfig),\
"optimizer_config must be either None or optim._OptimizerConfig"

self.all_finite = all_finite
self.step = step
self.optimizer_config = optimizer_config


class ORTTrainer(object):
r"""Pytorch frontend for ONNX Runtime training
Entry point that exposes the C++ backend of ORT as a Pytorch frontend.
Args:
model (torch.nn.Module or onnx.ModelProto): either a PyTorch or ONNX model.
When a PyTorch model and :py:attr:`loss_fn` are specified, :py:attr:`model` and :py:obj:`loss_fn` are combined.
When an ONNX model is provided, the loss is identified by the flag :py:obj:`is_loss=True` in one of the :py:attr:`.model_desc.outputs` entries.
model_desc (dict): model input and output description.
This is used to identify inputs and outputs and their shapes, so that ORT can generate the back-propagation graph, plan memory allocation for
training, and perform optimizations.
:py:attr:`model_desc` must be consistent with the training :py:attr:`model` and have the following (:py:obj:`dict`) schema
:py:obj:`{ 'inputs': [tuple(name, shape)], 'outputs': [tuple(name, shape, is_loss)]}`.
:py:attr:`name` is a string representing the name of input or output of the model.
For :py:obj:`model_desc['inputs']` entries, :py:attr:`name` must match input names of the original PyTorch model's :py:meth:`torch.nn.Module.forward` method.
For ONNX models, both name and order of input names must match.
For :py:obj:`model_desc['outputs']` entries, the order must match the original PyTorch's output as returned by :py:meth:`torch.nn.Module.forward` method.
For ONNX models, both name and order of output names must match.
:py:attr:`shape` is a list of strings or integers that describes the shape of the input/output.
Each dimension size can be either a string or an int: a string means the dimension is dynamic, while an integer means it is static.
An empty list implies a scalar.
Lastly, :py:attr:`is_loss` is a boolean (default is False) that flags if this output is considered a loss.
The ORT backend needs to know which output is the loss in order to generate the back-propagation graph.
The loss output must be specified whether :py:attr:`loss_fn` is provided or the loss is embedded in the model.
Note that only one loss output is supported per model.
optimizer_config (optim._OptimizerConfig): optimizer config.
One of :py:class:`.optim.AdamConfig`, :py:class:`.optim.LambConfig` or :py:class:`.optim.SGDConfig`.
loss_fn (default is None): a PyTorch loss function.
It takes two inputs [prediction, label] and outputs a loss tensor.
If provided, :py:attr:`loss_fn` is combined with the PyTorch :py:attr:`model` to form a combined PyTorch model.
Inputs to the combined PyTorch model are the concatenation of the :py:attr:`model`'s input and :py:attr:`loss_fn`'s label input.
Outputs of the combined PyTorch model are the concatenation of :py:attr:`loss_fn`'s loss output and :py:attr:`model`'s outputs.
options (ORTTrainerOptions, default is None): options for additional features.
Example:
.. code-block:: python
model = ...
model_desc = {
"inputs": [
("input_ids", ["batch", "max_seq_len_in_batch"]),
("attention_mask", ["batch", "max_seq_len_in_batch"]),
("token_type_ids", ["batch", "max_seq_len_in_batch"]),
("masked_lm_labels", ["batch", "max_seq_len_in_batch"]),
("next_sentence_label", ["batch", 1])
],
"outputs": [
("loss", [], True),
],
}
optim_config = optim.LambConfig(param_groups = [ { 'params' : ['model_param0'], 'alpha' : 0.8, 'beta' : 0.7},
{ 'params' : ['model_param1' , 'model_param_2'], 'alpha' : 0.0}
],
alpha=0.9, beta=0.999)
ort_trainer = ORTTrainer(model, model_desc, optim_config)
"""

def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None):
# Basic validation
assert model is not None, "'model' is required and must be either a 'torch.nn.Module' or ONNX model"
assert isinstance(model_desc, dict), "'model_desc' must be a 'dict'"
assert isinstance(optim_config, optim._OptimizerConfig),\
"'optim_config' is required and must be any of 'AdamConfig', 'LambConfig' or 'SGDConfig'"
assert loss_fn is None or isinstance(loss_fn, torch.nn.Module),\
"'loss_fn' must be either 'None' or 'torch.nn.Module'"
assert options is None or isinstance(options, ORTTrainerOptions),\
"'loss_fn' must be either 'None' or 'ORTTrainerOptions'"

# Model + Loss validation
# Supported combinations are
# ----------------------------------------
# | | Model | Loss |
# ----------------------------------------
# | 1 | torch.nn.Module | None |
# | 2 | torch.nn.Module | torch.nn.Module |
# | 3 | ONNX | None |
# ----------------------------------------
self._torch_model = None
self._onnx_model = None
if isinstance(model, torch.nn.Module):
assert loss_fn is None or isinstance(loss_fn, torch.nn.Module),\
"'loss_fn' must be either 'None' or 'torch.nn.Module'"
self._torch_model = model
self._loss_fn = loss_fn
elif isinstance(model, onnx.ModelProto):
assert loss_fn is None, "'loss_fn' must not be specified when 'model' is an ONNX model"
self._onnx_model = model
self._loss_fn = None
else:
raise ValueError("'model' must be either 'torch.nn.Module' or 'onnx.ModelProto'")

self.model_desc = _ORTTrainerModelDesc(model_desc)
self.optim_config = optim_config
self.options = ORTTrainerOptions(options)

def eval_step(self, *input, **kwargs):
r"""Evaluation step method
Args:
*input: Arbitrary arguments that are used as model input (data only)
**kwargs: Arbitrary keyword arguments that are used as model input (data only)
Returns:
ordered :py:obj:`list` with model outputs as described by :py:attr:`.ORTTrainer.model_desc`
"""
pass

def save_as_onnx(self, path):
r"""Persists ONNX model into :py:attr:`path`
The model will be saved as a Google Protocol Buffers (aka protobuf) file as per ONNX standard containing
the full graph, including inference and training metadata.
Args:
path (str): Full path, including filename, to save the model in the filesystem
"""
pass

def train_step(self, *input, **kwargs):
r"""Train step method
After the forward pass, an ordered list with all outputs described in :py:attr:`ORTTrainer.model_desc` is returned.
Additional information relevant to the train step is maintained by :py:attr:`ORTTrainer._train_step_info`.
See :py:class:`.TrainStepInfo` for details.
Args:
*input: Arbitrary arguments that are used as model input (data only)
**kwargs: Arbitrary keyword arguments that are used as model input (data only)
Returns:
ordered :py:obj:`list` with model outputs as described by :py:attr:`ORTTrainer.model_desc`
"""
pass
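
For reference, a short sketch of the three supported model/loss combinations validated by the constructor above; the import path, model classes, loss module, and ONNX file name are hypothetical, and model_desc follows the documented schema:

import torch

from onnxruntime.training import ORTTrainer, optim  # assumed import path

model_desc = {"inputs": [("x", ["batch", 10]), ("target", ["batch", 1])],
              "outputs": [("loss", [], True)]}
optim_config = optim.SGDConfig(lr=0.01)

class NetWithLoss(torch.nn.Module):
    """Hypothetical model that computes its own loss inside forward()."""
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 1)
        self.loss = torch.nn.MSELoss()

    def forward(self, x, target):
        return self.loss(self.linear(x), target)

class Net(torch.nn.Module):
    """Hypothetical model that returns predictions only."""
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

# 1. torch.nn.Module without loss_fn: the model computes the loss itself.
trainer1 = ORTTrainer(NetWithLoss(), model_desc, optim_config)

# 2. torch.nn.Module plus a torch.nn.Module loss_fn: ORT combines the two.
trainer2 = ORTTrainer(Net(), model_desc, optim_config, loss_fn=torch.nn.MSELoss())

# 3. onnx.ModelProto without loss_fn: the loss is the output flagged is_loss=True.
#    Passing loss_fn together with an ONNX model raises an AssertionError.
# trainer3 = ORTTrainer(onnx.load("model_with_loss.onnx"), model_desc, optim_config)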
