Skip to content

Commit

Permalink
Update ModelDescription and minor fix on ORTTrainer ctor (#4605)
Browse files Browse the repository at this point in the history
* Update ModelDescription and minor fix on ORTTrainer/ORTTrainerOptions

This PR keeps the public API intact, but changes how model description is stored on the backend

Currently, users create a dict with two lists of tuples.
One list called 'inputs' and each tuple has the following format tuple(name, shape).
The second list is called 'outputs' and each tuple can be either tuple(name, shape) or tuple(name, shape, is_loss).

With this PR, when this dict is passed in to ORTTrainer, it is fully validated as usual.
However, tuples are internally replaced by namedtuples and all output tuples will have
tuple(name, shape, is_loss) format instead of is_loss being optionally present.

In addition to that normalization of the internal representation (which eases coding),
two internal methods were created to replace a namedtuple(name, shape) with namedtuple(name, shape, dtype)
or a namedtuple(name, shape, is_loss) with namedtuple(name, shape, is_loss, dtype), depending on whether the tuple is an input or an output.

This is necessary because ORTTrainer only finds out the data type of each input/output during model export to ONNX.

Finally, a minor fix was made to ORTTrainer: it could initialize ORTTrainerOptions incorrectly when options=None.

* Rename input name for test
  • Loading branch information
Thiago Crepaldi committed Aug 14, 2020
1 parent 7b977f0 commit 462c357
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 16 deletions.
50 changes: 46 additions & 4 deletions orttraining/orttraining/python/training/model_desc_validation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import cerberus

from collections import namedtuple
import torch
from ._utils import static_vars


Expand All @@ -24,29 +25,69 @@ def __init__(self, model_desc):
if self._validated is None:
raise ValueError(f'Invalid model_desc: {validator.errors}')

# Normalize inputs to a list of namedtuple(name, shape)
self._InputDescription = namedtuple('InputDescription', ['name', 'shape'])
self._InputDescriptionTyped = namedtuple('InputDescriptionTyped', ['name', 'shape', 'dtype'])
for idx, input in enumerate(self._validated['inputs']):
self._validated['inputs'][idx] = self._InputDescription(*input)

# Normalize outputs to a list of namedtuple(name, shape, is_loss)
self._OutputDescription = namedtuple('OutputDescription', ['name', 'shape', 'is_loss'])
self._OutputDescriptionTyped = namedtuple('OutputDescriptionTyped', ['name', 'shape', 'is_loss', 'dtype'])
for idx, output in enumerate(self._validated['outputs']):
if len(output) == 2:
self._validated['outputs'][idx] = self._OutputDescription(*output, False)
else:
self._validated['outputs'][idx] = self._OutputDescription(*output)

# Convert dict in object
for k, v in self._validated.items():
setattr(self, k, self._wrap(v))

# Keep this in the last line
# After this point, this class becomes immutable
# NOTE: The embedded lists are still mutable
self._initialized = True

def __repr__(self):
return '{%s}' % str(', '.join("'%s': %s" % (k, repr(v))
for (k, v) in self.__dict__.items()
if k not in ['_main_class_name', '_original', '_validated', '_initialized']))
if k not in ['_main_class_name', '_original', '_validated',
'_InputDescription', '_InputDescriptionTyped',
'_OutputDescription', '_OutputDescriptionTyped']))

def __setattr__(self, k, v):
if hasattr(self, '_initialized'):
raise Exception(f"{self._main_class_name} is an immutable class")
return super().__setattr__(k, v)

def _wrap(self, v):
if isinstance(v, (tuple, list, set, frozenset)):
if isinstance(v, (list)):
return type(v)([self._wrap(v) for v in v])
else:
elif isinstance(v, (self._InputDescription, self._InputDescriptionTyped,
self._OutputDescription, self._OutputDescriptionTyped)):
return v
elif isinstance(v, (tuple)):
return type(v)([self._wrap(v) for v in v])
elif isinstance(v, (dict, int, float, bool, str)):
return _ORTTrainerModelDescInternal(self._main_class_name, v) if isinstance(v, dict) else v
else:
raise ValueError("Unsupported type for model_desc."
"Only int, float, bool, str, list, tuple and dict are supported")

def add_type_to_input_description(self, index, dtype):
assert isinstance(index, int) and index >= 0,\
"input 'index' must be a positive int"
assert isinstance(dtype, torch.dtype),\
"input 'dtype' must be a torch.dtype type"
self.inputs[index] = self._InputDescriptionTyped(*self.inputs[index], dtype)

def add_type_to_output_description(self, index, dtype):
assert isinstance(index, int) and index >= 0,\
"output 'index' must be a positive int"
assert isinstance(dtype, torch.dtype),\
"output 'dtype' must be a torch.dtype type"
self.outputs[index] = self._OutputDescriptionTyped(*self.outputs[index], dtype)


class _ORTTrainerModelDescInternal(_ORTTrainerModelDesc):
Expand All @@ -65,6 +106,7 @@ def __init__(self, main_class_name, model_desc):

# Keep this in the last line
# After this point, this class becomes immutable
# NOTE: The embedded lists are still mutable
self._initialized = True


Expand Down
6 changes: 5 additions & 1 deletion orttraining/orttraining/python/training/orttrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from . import optim
from .model_desc_validation import _ORTTrainerModelDesc


class TrainStepInfo(object):
r"""Private class used to store runtime information from current train step.
Expand Down Expand Up @@ -141,7 +142,10 @@ def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None):

self.model_desc = _ORTTrainerModelDesc(model_desc)
self.optim_config = optim_config
self.options = ORTTrainerOptions(options)
if options:
self.options = ORTTrainerOptions(options)
else:
self.options = ORTTrainerOptions()

def eval_step(self, *input, **kwargs):
r"""Evaluation step method
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ class ORTTrainerOptions(object):
fp16_enabled = opts.mixed_precision.enabled
"""

def __init__(self, options):
def __init__(self, options={}):
# Keep a copy of original input for debug
self._original_opts = dict(options)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,51 @@ def testORTTrainerOptionsInvalidMixedPrecisionEnabledSchema():
assert str(e.value) == expected_msg


@pytest.mark.parametrize("test_input", [
@pytest.mark.parametrize("input_dict,input_dtype,output_dtype", [
({'inputs': [('in0', [])],
'outputs': [('out0', []), ('out1', [])]}),
'outputs': [('out0', []), ('out1', [])]},(torch.int,),(torch.float,torch.int32,)),
({'inputs': [('in0', ['batch', 2, 3])],
'outputs': [('out0', [], True)]}),
'outputs': [('out0', [], True)]}, (torch.int8,), (torch.int16,)),
({'inputs': [('in0', []), ('in1', [1]), ('in2', [1, 2]), ('in3', [1000, 'dyn_ax1']), ('in4', ['dyn_ax1', 'dyn_ax2', 'dyn_ax3'])],
'outputs': [('out0', [], True), ('out1', [1], False), ('out2', [1, 'dyn_ax1', 3])]})
'outputs': [('out0', [], True), ('out1', [1], False), ('out2', [1, 'dyn_ax1', 3])]},
(torch.float,torch.uint8,torch.bool,torch.double,torch.half,), (torch.float,torch.float,torch.int64))
])
def testORTTrainerModelDescValidSchemas(test_input):
def testORTTrainerModelDescValidSchemas(input_dict, input_dtype, output_dtype):
r''' Test different ways of using default values for incomplete input'''
md_val._ORTTrainerModelDesc(test_input)


@pytest.mark.parametrize("test_input,error_msg", [
# Validating model description from user
model_description = md_val._ORTTrainerModelDesc(input_dict)
for idx, i_desc in enumerate(model_description.inputs):
assert isinstance(i_desc, model_description._InputDescription)
assert len(i_desc) == 2
assert input_dict['inputs'][idx][0] == i_desc.name
assert input_dict['inputs'][idx][1] == i_desc.shape
for idx, o_desc in enumerate(model_description.outputs):
assert isinstance(o_desc, model_description._OutputDescription)
assert len(o_desc) == 3
assert input_dict['outputs'][idx][0] == o_desc.name
assert input_dict['outputs'][idx][1] == o_desc.shape
is_loss = input_dict['outputs'][idx][2] if len(input_dict['outputs'][idx]) == 3 else False
assert is_loss == o_desc.is_loss

# Append type to inputs/outputs tuples
for idx, i_desc in enumerate(model_description.inputs):
model_description.add_type_to_input_description(idx, input_dtype[idx])
for idx, o_desc in enumerate(model_description.outputs):
model_description.add_type_to_output_description(idx, output_dtype[idx])

# Verify inputs/outputs tuples are replaced by the typed counterparts
for idx, i_desc in enumerate(model_description.inputs):
assert len(i_desc) == 3
assert isinstance(i_desc, model_description._InputDescriptionTyped)
assert input_dtype[idx] == i_desc.dtype
for idx, o_desc in enumerate(model_description.outputs):
assert len(o_desc) == 4
assert isinstance(o_desc, model_description._OutputDescriptionTyped)
assert output_dtype[idx] == o_desc.dtype


@pytest.mark.parametrize("input_dict,error_msg", [
({'inputs': [(True, [])],
'outputs': [(True, [])]},
"Invalid model_desc: {'inputs': [{0: ['the first element of the tuple (aka name) must be a string']}], "
Expand All @@ -97,11 +128,17 @@ def testORTTrainerModelDescValidSchemas(test_input):
({'inputs': [('in1', [])],
'outputs': [('out1', [], True), ('out2', [], True)]},
"Invalid model_desc: {'outputs': [{1: ['only one is_loss can bet set to True']}]}"),
({'inputz': [('in1', [])],
'outputs': [('out1', [], True)]},
"Invalid model_desc: {'inputs': ['required field'], 'inputz': ['unknown field']}"),
({'inputs': [('in1', [])],
'outputz': [('out1', [], True)]},
"Invalid model_desc: {'outputs': ['required field'], 'outputz': ['unknown field']}"),
])
def testORTTrainerModelDescInvalidSchemas(test_input, error_msg):
def testORTTrainerModelDescInvalidSchemas(input_dict, error_msg):
r''' Test different ways of using default values for incomplete input'''
with pytest.raises(ValueError) as e:
md_val._ORTTrainerModelDesc(test_input)
md_val._ORTTrainerModelDesc(input_dict)
assert str(e.value) == error_msg


Expand Down

0 comments on commit 462c357

Please sign in to comment.