diff --git a/orttraining/orttraining/python/training/model_desc_validation.py b/orttraining/orttraining/python/training/model_desc_validation.py index fa377646212fe..df3c6d5149025 100644 --- a/orttraining/orttraining/python/training/model_desc_validation.py +++ b/orttraining/orttraining/python/training/model_desc_validation.py @@ -1,5 +1,6 @@ import cerberus - +from collections import namedtuple +import torch from ._utils import static_vars @@ -24,18 +25,36 @@ def __init__(self, model_desc): if self._validated is None: raise ValueError(f'Invalid model_desc: {validator.errors}') + # Normalize inputs to a list of namedtuple(name, shape) + self._InputDescription = namedtuple('InputDescription', ['name', 'shape']) + self._InputDescriptionTyped = namedtuple('InputDescriptionTyped', ['name', 'shape', 'dtype']) + for idx, input in enumerate(self._validated['inputs']): + self._validated['inputs'][idx] = self._InputDescription(*input) + + # Normalize outputs to a list of namedtuple(name, shape, is_loss) + self._OutputDescription = namedtuple('OutputDescription', ['name', 'shape', 'is_loss']) + self._OutputDescriptionTyped = namedtuple('OutputDescriptionTyped', ['name', 'shape', 'is_loss', 'dtype']) + for idx, output in enumerate(self._validated['outputs']): + if len(output) == 2: + self._validated['outputs'][idx] = self._OutputDescription(*output, False) + else: + self._validated['outputs'][idx] = self._OutputDescription(*output) + # Convert dict in object for k, v in self._validated.items(): setattr(self, k, self._wrap(v)) # Keep this in the last line # After this point, this class becomes immutable + # NOTE: The embedded lists are still mutable self._initialized = True def __repr__(self): return '{%s}' % str(', '.join("'%s': %s" % (k, repr(v)) for (k, v) in self.__dict__.items() - if k not in ['_main_class_name', '_original', '_validated', '_initialized'])) + if k not in ['_main_class_name', '_original', '_validated', + '_InputDescription', '_InputDescriptionTyped', + 
'_OutputDescription', '_OutputDescriptionTyped'])) def __setattr__(self, k, v): if hasattr(self, '_initialized'): @@ -43,10 +62,32 @@ def __setattr__(self, k, v): return super().__setattr__(k, v) def _wrap(self, v): - if isinstance(v, (tuple, list, set, frozenset)): + if isinstance(v, (list)): return type(v)([self._wrap(v) for v in v]) - else: + elif isinstance(v, (self._InputDescription, self._InputDescriptionTyped, + self._OutputDescription, self._OutputDescriptionTyped)): + return v + elif isinstance(v, (tuple)): + return type(v)([self._wrap(v) for v in v]) + elif isinstance(v, (dict, int, float, bool, str)): return _ORTTrainerModelDescInternal(self._main_class_name, v) if isinstance(v, dict) else v + else: + raise ValueError("Unsupported type for model_desc." + "Only int, float, bool, str, list, tuple and dict are supported") + + def add_type_to_input_description(self, index, dtype): + assert isinstance(index, int) and index >= 0,\ + "input 'index' must be a positive int" + assert isinstance(dtype, torch.dtype),\ + "input 'dtype' must be a torch.dtype type" + self.inputs[index] = self._InputDescriptionTyped(*self.inputs[index], dtype) + + def add_type_to_output_description(self, index, dtype): + assert isinstance(index, int) and index >= 0,\ + "output 'index' must be a positive int" + assert isinstance(dtype, torch.dtype),\ + "output 'dtype' must be a torch.dtype type" + self.outputs[index] = self._OutputDescriptionTyped(*self.outputs[index], dtype) class _ORTTrainerModelDescInternal(_ORTTrainerModelDesc): @@ -65,6 +106,7 @@ def __init__(self, main_class_name, model_desc): # Keep this in the last line # After this point, this class becomes immutable + # NOTE: The embedded lists are still mutable self._initialized = True diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py index d57c4271cd901..661440480a525 100644 --- a/orttraining/orttraining/python/training/orttrainer.py +++ 
b/orttraining/orttraining/python/training/orttrainer.py @@ -6,6 +6,7 @@ from . import optim from .model_desc_validation import _ORTTrainerModelDesc + class TrainStepInfo(object): r"""Private class used to store runtime information from current train step. @@ -141,7 +142,10 @@ def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None): self.model_desc = _ORTTrainerModelDesc(model_desc) self.optim_config = optim_config - self.options = ORTTrainerOptions(options) + if options: + self.options = ORTTrainerOptions(options) + else: + self.options = ORTTrainerOptions() def eval_step(self, *input, **kwargs): r"""Evaluation step method diff --git a/orttraining/orttraining/python/training/orttrainer_options.py b/orttraining/orttraining/python/training/orttrainer_options.py index 83d2ac3216a7e..e1ebea9613e06 100644 --- a/orttraining/orttraining/python/training/orttrainer_options.py +++ b/orttraining/orttraining/python/training/orttrainer_options.py @@ -222,7 +222,7 @@ class ORTTrainerOptions(object): fp16_enabled = opts.mixed_precision.enabled """ - def __init__(self, options): + def __init__(self, options={}): # Keep a copy of original input for debug self._original_opts = dict(options) diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py index 52ee34ec1039e..882a206933094 100644 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py +++ b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py @@ -65,20 +65,51 @@ def testORTTrainerOptionsInvalidMixedPrecisionEnabledSchema(): assert str(e.value) == expected_msg -@pytest.mark.parametrize("test_input", [ +@pytest.mark.parametrize("input_dict,input_dtype,output_dtype", [ ({'inputs': [('in0', [])], - 'outputs': [('out0', []), ('out1', [])]}), + 'outputs': [('out0', []), ('out1', [])]},(torch.int,),(torch.float,torch.int32,)), ({'inputs': [('in0', 
['batch', 2, 3])], - 'outputs': [('out0', [], True)]}), + 'outputs': [('out0', [], True)]}, (torch.int8,), (torch.int16,)), ({'inputs': [('in0', []), ('in1', [1]), ('in2', [1, 2]), ('in3', [1000, 'dyn_ax1']), ('in4', ['dyn_ax1', 'dyn_ax2', 'dyn_ax3'])], - 'outputs': [('out0', [], True), ('out1', [1], False), ('out2', [1, 'dyn_ax1', 3])]}) + 'outputs': [('out0', [], True), ('out1', [1], False), ('out2', [1, 'dyn_ax1', 3])]}, + (torch.float,torch.uint8,torch.bool,torch.double,torch.half,), (torch.float,torch.float,torch.int64)) ]) -def testORTTrainerModelDescValidSchemas(test_input): +def testORTTrainerModelDescValidSchemas(input_dict, input_dtype, output_dtype): r''' Test different ways of using default values for incomplete input''' - md_val._ORTTrainerModelDesc(test_input) - -@pytest.mark.parametrize("test_input,error_msg", [ + # Validating model description from user + model_description = md_val._ORTTrainerModelDesc(input_dict) + for idx, i_desc in enumerate(model_description.inputs): + assert isinstance(i_desc, model_description._InputDescription) + assert len(i_desc) == 2 + assert input_dict['inputs'][idx][0] == i_desc.name + assert input_dict['inputs'][idx][1] == i_desc.shape + for idx, o_desc in enumerate(model_description.outputs): + assert isinstance(o_desc, model_description._OutputDescription) + assert len(o_desc) == 3 + assert input_dict['outputs'][idx][0] == o_desc.name + assert input_dict['outputs'][idx][1] == o_desc.shape + is_loss = input_dict['outputs'][idx][2] if len(input_dict['outputs'][idx]) == 3 else False + assert is_loss == o_desc.is_loss + + # Append type to inputs/outputs tuples + for idx, i_desc in enumerate(model_description.inputs): + model_description.add_type_to_input_description(idx, input_dtype[idx]) + for idx, o_desc in enumerate(model_description.outputs): + model_description.add_type_to_output_description(idx, output_dtype[idx]) + + # Verify inputs/outputs tuples are replaced by the typed counterparts + for idx, i_desc in 
enumerate(model_description.inputs): + assert len(i_desc) == 3 + assert isinstance(i_desc, model_description._InputDescriptionTyped) + assert input_dtype[idx] == i_desc.dtype + for idx, o_desc in enumerate(model_description.outputs): + assert len(o_desc) == 4 + assert isinstance(o_desc, model_description._OutputDescriptionTyped) + assert output_dtype[idx] == o_desc.dtype + + +@pytest.mark.parametrize("input_dict,error_msg", [ ({'inputs': [(True, [])], 'outputs': [(True, [])]}, "Invalid model_desc: {'inputs': [{0: ['the first element of the tuple (aka name) must be a string']}], " @@ -97,11 +128,17 @@ def testORTTrainerModelDescValidSchemas(test_input): ({'inputs': [('in1', [])], 'outputs': [('out1', [], True), ('out2', [], True)]}, "Invalid model_desc: {'outputs': [{1: ['only one is_loss can bet set to True']}]}"), + ({'inputz': [('in1', [])], + 'outputs': [('out1', [], True)]}, + "Invalid model_desc: {'inputs': ['required field'], 'inputz': ['unknown field']}"), + ({'inputs': [('in1', [])], + 'outputz': [('out1', [], True)]}, + "Invalid model_desc: {'outputs': ['required field'], 'outputz': ['unknown field']}"), ]) -def testORTTrainerModelDescInvalidSchemas(test_input, error_msg): +def testORTTrainerModelDescInvalidSchemas(input_dict, error_msg): r''' Test different ways of using default values for incomplete input''' with pytest.raises(ValueError) as e: - md_val._ORTTrainerModelDesc(test_input) + md_val._ORTTrainerModelDesc(input_dict) assert str(e.value) == error_msg