Skip to content

Commit

Permalink
Add opset_version into ORTTrainerOptions and change type of ORTTraine…
Browse files Browse the repository at this point in the history
…r.loss_fn (#4592)
  • Loading branch information
Thiago Crepaldi committed Aug 14, 2020
1 parent 8e4d96a commit 7b977f0
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 10 deletions.
14 changes: 8 additions & 6 deletions orttraining/orttraining/python/training/orttrainer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import onnx
import torch
from inspect import signature

from . import ORTTrainerOptions
from . import optim
from .model_desc_validation import _ORTTrainerModelDesc


class TrainStepInfo(object):
r"""Private class used to store runtime information from current train step.
Expand Down Expand Up @@ -73,8 +73,8 @@ class ORTTrainer(object):
Note that only one loss output is supported per model.
optimizer_config (optim._OptimizerConfig): optimizer config.
One of :py:class:`.optim.AdamConfig`, :py:class:`.optim.LambConfig` or :py:class:`.optim.SGDConfig`.
loss_fn (default is None): a PyTorch loss function.
It takes two inputs [prediction, label] and output a loss tensor.
loss_fn (callable, default is None): a PyTorch loss function.
It takes two inputs [prediction, label] and outputs a scalar loss tensor.
If provided, :py:attr:`loss_fn` is combined with the PyTorch :py:attr:`model` to form a combined PyTorch model.
Inputs to the combined PyTorch model are the concatenation of the :py:attr:`model`'s inputs and :py:attr:`loss_fn`'s label input.
Outputs of the combined PyTorch model are the concatenation of :py:attr:`loss_fn`'s loss output and :py:attr:`model`'s outputs.
Expand All @@ -85,6 +85,7 @@ class ORTTrainer(object):
.. code-block:: python
model = ...
loss_fn = ...
model_desc = {
"inputs": [
("input_ids", ["batch", "max_seq_len_in_batch"]),
Expand All @@ -101,7 +102,7 @@ class ORTTrainer(object):
{ 'params' : ['model_param1' , 'model_param_2'], 'alpha' : 0.0}
],
alpha=0.9, beta=0.999)
ort_trainer = ORTTrainer(model, model_desc, optim_config)
ort_trainer = ORTTrainer(model, model_desc, optim_config, loss_fn)
"""

def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None):
Expand All @@ -110,8 +111,8 @@ def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None):
assert isinstance(model_desc, dict), "'model_desc' must be a 'dict'"
assert isinstance(optim_config, optim._OptimizerConfig),\
"'optim_config' is required and must be any of 'AdamConfig', 'LambConfig' or 'SGDConfig'"
assert loss_fn is None or isinstance(loss_fn, torch.nn.Module),\
"'loss_fn' must be either 'None' or 'torch.nn.Module'"
assert loss_fn is None or (callable(loss_fn) and len(signature(loss_fn).parameters) == 2),\
"'loss_fn' must be either 'None' or a callable with two parameters"
assert options is None or isinstance(options, ORTTrainerOptions),\
"'options' must be either 'None' or 'ORTTrainerOptions'"

Expand Down Expand Up @@ -165,6 +166,7 @@ def save_as_onnx(self, path):
"""
pass


def train_step(self, *input, **kwargs):
r"""Train step method
Expand Down
19 changes: 16 additions & 3 deletions orttraining/orttraining/python/training/orttrainer_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,15 @@ class ORTTrainerOptions(object):
'default' : True
},
'extra_postprocess' : {
'check_with' : 'callable',
'type' : 'callable',
'nullable' : True,
'default' : None
},
'onnx_opset_version': {
'type': 'integer',
'min' : 10,
'max' : 12,
'default': 12
}
}
}
Expand Down Expand Up @@ -187,12 +193,14 @@ class ORTTrainerOptions(object):
utils.grad_norm_clip (bool, default is False):
enables gradient norm clipping for 'AdamOptimizer' and 'LambOptimizer'
_internal_use (dict):
internal, possibly undocumented, options that might be removed in the next release
internal options, possibly undocumented, that might be removed without notice
_internal_use.enable_internal_postprocess (bool, default is True):
enable internal post processing of the ONNX model
_internal_use.extra_postprocess (callable, default is None):
a functor to postprocess the ONNX model.
It does not override :py:attr:`._internal_use.enable_internal_postprocess`, but complements it
_internal_use.onnx_opset_version (int, default is 12):
ONNX opset version used during model exporting.
Example:
.. code-block:: python
Expand Down Expand Up @@ -413,7 +421,12 @@ def _check_is_callable(field, value, error):
'check_with': _check_is_callable,
'nullable': True,
'default': None

},
'onnx_opset_version': {
'type': 'integer',
'min' : 10,
'max' : 12,
'default': 12
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def testORTTrainerOptionsDefaultValues(test_input):
},
'_internal_use': {
'enable_internal_postprocess': True,
'extra_postprocess': None
'extra_postprocess': None,
'onnx_opset_version' : 12
}
}

Expand Down

0 comments on commit 7b977f0

Please sign in to comment.