Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update ModelDescription and minor fix on ORTTrainer ctor #4605

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 46 additions & 4 deletions orttraining/orttraining/python/training/model_desc_validation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import cerberus

from collections import namedtuple
import torch
from ._utils import static_vars


Expand All @@ -24,29 +25,69 @@ def __init__(self, model_desc):
if self._validated is None:
raise ValueError(f'Invalid model_desc: {validator.errors}')

# Normalize inputs to a list of namedtuple(name, shape)
self._InputDescription = namedtuple('InputDescription', ['name', 'shape'])
self._InputDescriptionTyped = namedtuple('InputDescriptionTyped', ['name', 'shape', 'dtype'])
for idx, input in enumerate(self._validated['inputs']):
self._validated['inputs'][idx] = self._InputDescription(*input)

# Normalize outputs to a list of namedtuple(name, shape, is_loss)
self._OutputDescription = namedtuple('OutputDescription', ['name', 'shape', 'is_loss'])
self._OutputDescriptionTyped = namedtuple('OutputDescriptionTyped', ['name', 'shape', 'is_loss', 'dtype'])
for idx, output in enumerate(self._validated['outputs']):
if len(output) == 2:
self._validated['outputs'][idx] = self._OutputDescription(*output, False)
else:
self._validated['outputs'][idx] = self._OutputDescription(*output)

# Convert dict in object
for k, v in self._validated.items():
setattr(self, k, self._wrap(v))

# Keep this in the last line
# After this point, this class becomes immutable
# NOTE: The embedded lists are still mutable
self._initialized = True

def __repr__(self):
return '{%s}' % str(', '.join("'%s': %s" % (k, repr(v))
for (k, v) in self.__dict__.items()
if k not in ['_main_class_name', '_original', '_validated', '_initialized']))
if k not in ['_main_class_name', '_original', '_validated',
'_InputDescription', '_InputDescriptionTyped',
'_OutputDescription', '_OutputDescriptionTyped']))

def __setattr__(self, k, v):
    """Assign attribute ``k`` to ``v`` unless the instance is frozen.

    The instance counts as frozen as soon as it carries an ``_initialized``
    attribute; from then on every assignment is rejected.

    Raises:
        Exception: if ``_initialized`` has already been set on the instance.
    """
    frozen = hasattr(self, '_initialized')
    if not frozen:
        # Still under construction: defer to the normal attribute machinery.
        return super().__setattr__(k, v)
    raise Exception(f"{self._main_class_name} is an immutable class")

def _wrap(self, v):
if isinstance(v, (tuple, list, set, frozenset)):
if isinstance(v, (list)):
return type(v)([self._wrap(v) for v in v])
else:
elif isinstance(v, (self._InputDescription, self._InputDescriptionTyped,
self._OutputDescription, self._OutputDescriptionTyped)):
return v
elif isinstance(v, (tuple)):
return type(v)([self._wrap(v) for v in v])
elif isinstance(v, (dict, int, float, bool, str)):
return _ORTTrainerModelDescInternal(self._main_class_name, v) if isinstance(v, dict) else v
else:
raise ValueError("Unsupported type for model_desc."
"Only int, float, bool, str, list, tuple and dict are supported")

def add_type_to_input_description(self, index, dtype):
    """Attach a torch dtype to the input description stored at ``index``.

    Replaces ``self.inputs[index]`` with an ``_InputDescriptionTyped``
    namedtuple built from the entry's existing fields plus ``dtype``.
    If the entry is already typed, its previous dtype is replaced, so the
    method can be called more than once for the same index.

    Args:
        index: int >= 0 selecting the entry in ``self.inputs``.
        dtype: ``torch.dtype`` to record for that input.

    Raises:
        AssertionError: if ``index`` is not an int >= 0 or ``dtype`` is not
            a ``torch.dtype``.
    """
    assert isinstance(index, int) and index >= 0,\
        "input 'index' must be a positive int"
    assert isinstance(dtype, torch.dtype),\
        "input 'dtype' must be a torch.dtype type"
    current = self.inputs[index]
    # An already-typed entry ends with its old dtype; drop it so a repeated
    # call updates the dtype instead of passing one field too many.
    if isinstance(current, self._InputDescriptionTyped):
        current = current[:-1]
    self.inputs[index] = self._InputDescriptionTyped(*current, dtype)

def add_type_to_output_description(self, index, dtype):
    """Attach a torch dtype to the output description stored at ``index``.

    Replaces ``self.outputs[index]`` with an ``_OutputDescriptionTyped``
    namedtuple built from the entry's existing fields plus ``dtype``.
    If the entry is already typed, its previous dtype is replaced, so the
    method can be called more than once for the same index.

    Args:
        index: int >= 0 selecting the entry in ``self.outputs``.
        dtype: ``torch.dtype`` to record for that output.

    Raises:
        AssertionError: if ``index`` is not an int >= 0 or ``dtype`` is not
            a ``torch.dtype``.
    """
    assert isinstance(index, int) and index >= 0,\
        "output 'index' must be a positive int"
    assert isinstance(dtype, torch.dtype),\
        "output 'dtype' must be a torch.dtype type"
    current = self.outputs[index]
    # An already-typed entry ends with its old dtype; drop it so a repeated
    # call updates the dtype instead of passing one field too many.
    if isinstance(current, self._OutputDescriptionTyped):
        current = current[:-1]
    self.outputs[index] = self._OutputDescriptionTyped(*current, dtype)


class _ORTTrainerModelDescInternal(_ORTTrainerModelDesc):
Expand All @@ -65,6 +106,7 @@ def __init__(self, main_class_name, model_desc):

# Keep this in the last line
# After this point, this class becomes immutable
# NOTE: The embedded lists are still mutable
self._initialized = True


Expand Down
6 changes: 5 additions & 1 deletion orttraining/orttraining/python/training/orttrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from . import optim
from .model_desc_validation import _ORTTrainerModelDesc


class TrainStepInfo(object):
r"""Private class used to store runtime information from current train step.

Expand Down Expand Up @@ -141,7 +142,10 @@ def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None):

self.model_desc = _ORTTrainerModelDesc(model_desc)
self.optim_config = optim_config
self.options = ORTTrainerOptions(options)
if options:
self.options = ORTTrainerOptions(options)
else:
self.options = ORTTrainerOptions()

def eval_step(self, *input, **kwargs):
r"""Evaluation step method
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ class ORTTrainerOptions(object):
fp16_enabled = opts.mixed_precision.enabled
"""

def __init__(self, options):
def __init__(self, options={}):
# Keep a copy of original input for debug
self._original_opts = dict(options)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,20 +65,51 @@ def testORTTrainerOptionsInvalidMixedPrecisionEnabledSchema():
assert str(e.value) == expected_msg


@pytest.mark.parametrize("test_input", [
@pytest.mark.parametrize("test_data,input_dtype,output_dtype", [
({'inputs': [('in0', [])],
'outputs': [('out0', []), ('out1', [])]}),
'outputs': [('out0', []), ('out1', [])]},(torch.int,),(torch.float,torch.int32,)),
({'inputs': [('in0', ['batch', 2, 3])],
'outputs': [('out0', [], True)]}),
'outputs': [('out0', [], True)]}, (torch.int8,), (torch.int16,)),
({'inputs': [('in0', []), ('in1', [1]), ('in2', [1, 2]), ('in3', [1000, 'dyn_ax1']), ('in4', ['dyn_ax1', 'dyn_ax2', 'dyn_ax3'])],
'outputs': [('out0', [], True), ('out1', [1], False), ('out2', [1, 'dyn_ax1', 3])]})
'outputs': [('out0', [], True), ('out1', [1], False), ('out2', [1, 'dyn_ax1', 3])]},
(torch.float,torch.uint8,torch.bool,torch.double,torch.half,), (torch.float,torch.float,torch.int64))
])
def testORTTrainerModelDescValidSchemas(test_input):
def testORTTrainerModelDescValidSchemas(test_data, input_dtype, output_dtype):
r''' Test different ways of using default values for incomplete input'''
md_val._ORTTrainerModelDesc(test_input)


@pytest.mark.parametrize("test_input,error_msg", [
# Validating model description from user
model_description = md_val._ORTTrainerModelDesc(test_data)
for idx, i_desc in enumerate(model_description.inputs):
assert isinstance(i_desc, model_description._InputDescription)
assert len(i_desc) == 2
assert test_data['inputs'][idx][0] == i_desc.name
assert test_data['inputs'][idx][1] == i_desc.shape
for idx, o_desc in enumerate(model_description.outputs):
assert isinstance(o_desc, model_description._OutputDescription)
assert len(o_desc) == 3
assert test_data['outputs'][idx][0] == o_desc.name
assert test_data['outputs'][idx][1] == o_desc.shape
is_loss = test_data['outputs'][idx][2] if len(test_data['outputs'][idx]) == 3 else False
assert is_loss == o_desc.is_loss

# Append type to inputs/outputs tuples
for idx, i_desc in enumerate(model_description.inputs):
model_description.add_type_to_input_description(idx, input_dtype[idx])
for idx, o_desc in enumerate(model_description.outputs):
model_description.add_type_to_output_description(idx, output_dtype[idx])

# Verify inputs/outputs tuples are replaced by the typed counterparts
for idx, i_desc in enumerate(model_description.inputs):
assert len(i_desc) == 3
assert isinstance(i_desc, model_description._InputDescriptionTyped)
assert input_dtype[idx] == i_desc.dtype
for idx, o_desc in enumerate(model_description.outputs):
assert len(o_desc) == 4
assert isinstance(o_desc, model_description._OutputDescriptionTyped)
assert output_dtype[idx] == o_desc.dtype


@pytest.mark.parametrize("test_data,error_msg", [
({'inputs': [(True, [])],
'outputs': [(True, [])]},
"Invalid model_desc: {'inputs': [{0: ['the first element of the tuple (aka name) must be a string']}], "
Expand All @@ -97,11 +128,17 @@ def testORTTrainerModelDescValidSchemas(test_input):
({'inputs': [('in1', [])],
'outputs': [('out1', [], True), ('out2', [], True)]},
"Invalid model_desc: {'outputs': [{1: ['only one is_loss can bet set to True']}]}"),
({'inputz': [('in1', [])],
'outputs': [('out1', [], True)]},
"Invalid model_desc: {'inputs': ['required field'], 'inputz': ['unknown field']}"),
({'inputs': [('in1', [])],
'outputz': [('out1', [], True)]},
"Invalid model_desc: {'outputs': ['required field'], 'outputz': ['unknown field']}"),
])
def testORTTrainerModelDescInvalidSchemas(test_input, error_msg):
def testORTTrainerModelDescInvalidSchemas(test_data, error_msg):
thiagocrepaldi marked this conversation as resolved.
Show resolved Hide resolved
r''' Test different ways of using default values for incomplete input'''
with pytest.raises(ValueError) as e:
md_val._ORTTrainerModelDesc(test_input)
md_val._ORTTrainerModelDesc(test_data)
assert str(e.value) == error_msg


Expand Down