From 6b513899b4da2fc9c53139c5e89b0493c5c426bb Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Thu, 23 Jul 2020 14:56:14 -0700 Subject: [PATCH 1/2] Update ModelDescription and minor fix on ORTTrainer/ORTTrainerOptions This PR keeps the public API intact, but changes how model description is stored on the backend. Currently, users create a dict with two lists of tuples. One list is called 'inputs' and each tuple has the following format tuple(name, shape). The second list is called 'outputs' and each tuple can be either tuple(name, shape) or tuple(name, shape, is_loss). With this PR, when this dict is passed in to ORTTrainer, it is fully validated as usual. However, tuples are internally replaced by namedtuples and all output tuples will have tuple(name, shape, is_loss) format instead of is_loss being optionally present. In addition to that normalization in the internal representation (which eases coding), two internal methods were created to replace a namedtuple(name, shape) with namedtuple(name, shape, dtype) or namedtuple(name, shape, is_loss, dtype) depending on whether the tuple is an input or output. This is necessary as ORTTrainer finds out data types of each input/output during model export to ONNX. Finally, a minor fix was done on ORTTrainer. 
It could initialize ORTTrainerOptions incorrectly when options=None --- .../python/training/model_desc_validation.py | 50 ++++++++++++++-- .../orttraining/python/training/orttrainer.py | 6 +- .../python/training/orttrainer_options.py | 2 +- .../orttraining_test_orttrainer_frontend.py | 57 +++++++++++++++---- 4 files changed, 99 insertions(+), 16 deletions(-) diff --git a/orttraining/orttraining/python/training/model_desc_validation.py b/orttraining/orttraining/python/training/model_desc_validation.py index fa377646212fe..df3c6d5149025 100644 --- a/orttraining/orttraining/python/training/model_desc_validation.py +++ b/orttraining/orttraining/python/training/model_desc_validation.py @@ -1,5 +1,6 @@ import cerberus - +from collections import namedtuple +import torch from ._utils import static_vars @@ -24,18 +25,36 @@ def __init__(self, model_desc): if self._validated is None: raise ValueError(f'Invalid model_desc: {validator.errors}') + # Normalize inputs to a list of namedtuple(name, shape) + self._InputDescription = namedtuple('InputDescription', ['name', 'shape']) + self._InputDescriptionTyped = namedtuple('InputDescriptionTyped', ['name', 'shape', 'dtype']) + for idx, input in enumerate(self._validated['inputs']): + self._validated['inputs'][idx] = self._InputDescription(*input) + + # Normalize outputs to a list of namedtuple(name, shape, is_loss) + self._OutputDescription = namedtuple('OutputDescription', ['name', 'shape', 'is_loss']) + self._OutputDescriptionTyped = namedtuple('OutputDescriptionTyped', ['name', 'shape', 'is_loss', 'dtype']) + for idx, output in enumerate(self._validated['outputs']): + if len(output) == 2: + self._validated['outputs'][idx] = self._OutputDescription(*output, False) + else: + self._validated['outputs'][idx] = self._OutputDescription(*output) + # Convert dict in object for k, v in self._validated.items(): setattr(self, k, self._wrap(v)) # Keep this in the last line # After this point, this class becomes immutable + # NOTE: The 
embedded lists are still muttable self._initialized = True def __repr__(self): return '{%s}' % str(', '.join("'%s': %s" % (k, repr(v)) for (k, v) in self.__dict__.items() - if k not in ['_main_class_name', '_original', '_validated', '_initialized'])) + if k not in ['_main_class_name', '_original', '_validated', + '_InputDescription', '_InputDescriptionTyped', + '_OutputDescription', '_OutputDescriptionTyped'])) def __setattr__(self, k, v): if hasattr(self, '_initialized'): @@ -43,10 +62,32 @@ def __setattr__(self, k, v): return super().__setattr__(k, v) def _wrap(self, v): - if isinstance(v, (tuple, list, set, frozenset)): + if isinstance(v, (list)): return type(v)([self._wrap(v) for v in v]) - else: + elif isinstance(v, (self._InputDescription, self._InputDescriptionTyped, + self._OutputDescription, self._OutputDescriptionTyped)): + return v + elif isinstance(v, (tuple)): + return type(v)([self._wrap(v) for v in v]) + elif isinstance(v, (dict, int, float, bool, str)): return _ORTTrainerModelDescInternal(self._main_class_name, v) if isinstance(v, dict) else v + else: + raise ValueError("Unsupported type for model_desc." 
+ "Only int, float, bool, str, list, tuple and dict are supported") + + def add_type_to_input_description(self, index, dtype): + assert isinstance(index, int) and index >= 0,\ + "input 'index' must be a positive int" + assert isinstance(dtype, torch.dtype),\ + "input 'dtype' must be a torch.dtype type" + self.inputs[index] = self._InputDescriptionTyped(*self.inputs[index], dtype) + + def add_type_to_output_description(self, index, dtype): + assert isinstance(index, int) and index >= 0,\ + "output 'index' must be a positive int" + assert isinstance(dtype, torch.dtype),\ + "output 'dtype' must be a torch.dtype type" + self.outputs[index] = self._OutputDescriptionTyped(*self.outputs[index], dtype) class _ORTTrainerModelDescInternal(_ORTTrainerModelDesc): @@ -65,6 +106,7 @@ def __init__(self, main_class_name, model_desc): # Keep this in the last line # After this point, this class becomes immutable + # NOTE: The embedded lists are still muttable self._initialized = True diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py index d57c4271cd901..661440480a525 100644 --- a/orttraining/orttraining/python/training/orttrainer.py +++ b/orttraining/orttraining/python/training/orttrainer.py @@ -6,6 +6,7 @@ from . import optim from .model_desc_validation import _ORTTrainerModelDesc + class TrainStepInfo(object): r"""Private class used to store runtime information from current train step. 
@@ -141,7 +142,10 @@ def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None): self.model_desc = _ORTTrainerModelDesc(model_desc) self.optim_config = optim_config - self.options = ORTTrainerOptions(options) + if options: + self.options = ORTTrainerOptions(options) + else: + self.options = ORTTrainerOptions() def eval_step(self, *input, **kwargs): r"""Evaluation step method diff --git a/orttraining/orttraining/python/training/orttrainer_options.py b/orttraining/orttraining/python/training/orttrainer_options.py index 83d2ac3216a7e..e1ebea9613e06 100644 --- a/orttraining/orttraining/python/training/orttrainer_options.py +++ b/orttraining/orttraining/python/training/orttrainer_options.py @@ -222,7 +222,7 @@ class ORTTrainerOptions(object): fp16_enabled = opts.mixed_precision.enabled """ - def __init__(self, options): + def __init__(self, options={}): # Keep a copy of original input for debug self._original_opts = dict(options) diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py index 52ee34ec1039e..755a5b47458b2 100644 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py +++ b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py @@ -65,20 +65,51 @@ def testORTTrainerOptionsInvalidMixedPrecisionEnabledSchema(): assert str(e.value) == expected_msg -@pytest.mark.parametrize("test_input", [ +@pytest.mark.parametrize("test_data,input_dtype,output_dtype", [ ({'inputs': [('in0', [])], - 'outputs': [('out0', []), ('out1', [])]}), + 'outputs': [('out0', []), ('out1', [])]},(torch.int,),(torch.float,torch.int32,)), ({'inputs': [('in0', ['batch', 2, 3])], - 'outputs': [('out0', [], True)]}), + 'outputs': [('out0', [], True)]}, (torch.int8,), (torch.int16,)), ({'inputs': [('in0', []), ('in1', [1]), ('in2', [1, 2]), ('in3', [1000, 'dyn_ax1']), ('in4', ['dyn_ax1', 'dyn_ax2', 'dyn_ax3'])], - 
'outputs': [('out0', [], True), ('out1', [1], False), ('out2', [1, 'dyn_ax1', 3])]}) + 'outputs': [('out0', [], True), ('out1', [1], False), ('out2', [1, 'dyn_ax1', 3])]}, + (torch.float,torch.uint8,torch.bool,torch.double,torch.half,), (torch.float,torch.float,torch.int64)) ]) -def testORTTrainerModelDescValidSchemas(test_input): +def testORTTrainerModelDescValidSchemas(test_data, input_dtype, output_dtype): r''' Test different ways of using default values for incomplete input''' - md_val._ORTTrainerModelDesc(test_input) - -@pytest.mark.parametrize("test_input,error_msg", [ + # Validating model description from user + model_description = md_val._ORTTrainerModelDesc(test_data) + for idx, i_desc in enumerate(model_description.inputs): + assert isinstance(i_desc, model_description._InputDescription) + assert len(i_desc) == 2 + assert test_data['inputs'][idx][0] == i_desc.name + assert test_data['inputs'][idx][1] == i_desc.shape + for idx, o_desc in enumerate(model_description.outputs): + assert isinstance(o_desc, model_description._OutputDescription) + assert len(o_desc) == 3 + assert test_data['outputs'][idx][0] == o_desc.name + assert test_data['outputs'][idx][1] == o_desc.shape + is_loss = test_data['outputs'][idx][2] if len(test_data['outputs'][idx]) == 3 else False + assert is_loss == o_desc.is_loss + + # Append type to inputs/outputs tuples + for idx, i_desc in enumerate(model_description.inputs): + model_description.add_type_to_input_description(idx, input_dtype[idx]) + for idx, o_desc in enumerate(model_description.outputs): + model_description.add_type_to_output_description(idx, output_dtype[idx]) + + # Verify inputs/outputs tuples are replaced by the typed counterparts + for idx, i_desc in enumerate(model_description.inputs): + assert len(i_desc) == 3 + assert isinstance(i_desc, model_description._InputDescriptionTyped) + assert input_dtype[idx] == i_desc.dtype + for idx, o_desc in enumerate(model_description.outputs): + assert len(o_desc) == 4 + assert 
isinstance(o_desc, model_description._OutputDescriptionTyped) + assert output_dtype[idx] == o_desc.dtype + + +@pytest.mark.parametrize("test_data,error_msg", [ ({'inputs': [(True, [])], 'outputs': [(True, [])]}, "Invalid model_desc: {'inputs': [{0: ['the first element of the tuple (aka name) must be a string']}], " @@ -97,11 +128,17 @@ def testORTTrainerModelDescValidSchemas(test_input): ({'inputs': [('in1', [])], 'outputs': [('out1', [], True), ('out2', [], True)]}, "Invalid model_desc: {'outputs': [{1: ['only one is_loss can bet set to True']}]}"), + ({'inputz': [('in1', [])], + 'outputs': [('out1', [], True)]}, + "Invalid model_desc: {'inputs': ['required field'], 'inputz': ['unknown field']}"), + ({'inputs': [('in1', [])], + 'outputz': [('out1', [], True)]}, + "Invalid model_desc: {'outputs': ['required field'], 'outputz': ['unknown field']}"), ]) -def testORTTrainerModelDescInvalidSchemas(test_input, error_msg): +def testORTTrainerModelDescInvalidSchemas(test_data, error_msg): r''' Test different ways of using default values for incomplete input''' with pytest.raises(ValueError) as e: - md_val._ORTTrainerModelDesc(test_input) + md_val._ORTTrainerModelDesc(test_data) assert str(e.value) == error_msg From f4c985b199d48e00f7d36ef46250f83955cf1392 Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Fri, 24 Jul 2020 10:34:16 -0700 Subject: [PATCH 2/2] Rename input name for test --- .../orttraining_test_orttrainer_frontend.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py index 755a5b47458b2..882a206933094 100644 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py +++ b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py @@ -65,7 +65,7 @@ def testORTTrainerOptionsInvalidMixedPrecisionEnabledSchema(): assert str(e.value) 
== expected_msg -@pytest.mark.parametrize("test_data,input_dtype,output_dtype", [ +@pytest.mark.parametrize("input_dict,input_dtype,output_dtype", [ ({'inputs': [('in0', [])], 'outputs': [('out0', []), ('out1', [])]},(torch.int,),(torch.float,torch.int32,)), ({'inputs': [('in0', ['batch', 2, 3])], @@ -74,22 +74,22 @@ def testORTTrainerOptionsInvalidMixedPrecisionEnabledSchema(): 'outputs': [('out0', [], True), ('out1', [1], False), ('out2', [1, 'dyn_ax1', 3])]}, (torch.float,torch.uint8,torch.bool,torch.double,torch.half,), (torch.float,torch.float,torch.int64)) ]) -def testORTTrainerModelDescValidSchemas(test_data, input_dtype, output_dtype): +def testORTTrainerModelDescValidSchemas(input_dict, input_dtype, output_dtype): r''' Test different ways of using default values for incomplete input''' # Validating model description from user - model_description = md_val._ORTTrainerModelDesc(test_data) + model_description = md_val._ORTTrainerModelDesc(input_dict) for idx, i_desc in enumerate(model_description.inputs): assert isinstance(i_desc, model_description._InputDescription) assert len(i_desc) == 2 - assert test_data['inputs'][idx][0] == i_desc.name - assert test_data['inputs'][idx][1] == i_desc.shape + assert input_dict['inputs'][idx][0] == i_desc.name + assert input_dict['inputs'][idx][1] == i_desc.shape for idx, o_desc in enumerate(model_description.outputs): assert isinstance(o_desc, model_description._OutputDescription) assert len(o_desc) == 3 - assert test_data['outputs'][idx][0] == o_desc.name - assert test_data['outputs'][idx][1] == o_desc.shape - is_loss = test_data['outputs'][idx][2] if len(test_data['outputs'][idx]) == 3 else False + assert input_dict['outputs'][idx][0] == o_desc.name + assert input_dict['outputs'][idx][1] == o_desc.shape + is_loss = input_dict['outputs'][idx][2] if len(input_dict['outputs'][idx]) == 3 else False assert is_loss == o_desc.is_loss # Append type to inputs/outputs tuples @@ -109,7 +109,7 @@ def 
testORTTrainerModelDescValidSchemas(test_data, input_dtype, output_dtype): assert output_dtype[idx] == o_desc.dtype -@pytest.mark.parametrize("test_data,error_msg", [ +@pytest.mark.parametrize("input_dict,error_msg", [ ({'inputs': [(True, [])], 'outputs': [(True, [])]}, "Invalid model_desc: {'inputs': [{0: ['the first element of the tuple (aka name) must be a string']}], " @@ -135,10 +135,10 @@ def testORTTrainerModelDescValidSchemas(test_data, input_dtype, output_dtype): 'outputz': [('out1', [], True)]}, "Invalid model_desc: {'outputs': ['required field'], 'outputz': ['unknown field']}"), ]) -def testORTTrainerModelDescInvalidSchemas(test_data, error_msg): +def testORTTrainerModelDescInvalidSchemas(input_dict, error_msg): r''' Test different ways of using default values for incomplete input''' with pytest.raises(ValueError) as e: - md_val._ORTTrainerModelDesc(test_data) + md_val._ORTTrainerModelDesc(input_dict) assert str(e.value) == error_msg