From 3fb07abaefa17fa824580be88a4e87a041787312 Mon Sep 17 00:00:00 2001
From: Thiago Crepaldi
Date: Thu, 18 Jun 2020 15:39:57 -0700
Subject: [PATCH] Fix bugs and add tests

---
 .../python/training/optim/config.py           | 15 +++-
 ...ttraining_test_pytorch_trainer_frontend.py | 80 +++++++++++++++++--
 2 files changed, 87 insertions(+), 8 deletions(-)

diff --git a/orttraining/orttraining/python/training/optim/config.py b/orttraining/orttraining/python/training/optim/config.py
index 05608f161962e..6e97ac38086d0 100644
--- a/orttraining/orttraining/python/training/optim/config.py
+++ b/orttraining/orttraining/python/training/optim/config.py
@@ -35,7 +35,7 @@ def __init__(self, name, hyper_parameters, param_groups=[]):
             "'name' must be one of 'AdamOptimizer', 'LambOptimizer' or 'SGDOptimizer'"
         assert isinstance(hyper_parameters, dict), "'hyper_parameters' must be a dict"
         assert 'lr' in hyper_parameters, "'hyper_parameters' must contain a {'lr' : positive number} entry"
-        assert hyper_parameters['lr'] >= 0, "lr must be a positive number"
+        assert isinstance(hyper_parameters['lr'], float) and hyper_parameters['lr'] >= 0, "lr must be a non-negative float"
         assert isinstance(param_groups, list), "'param_groups' must be a list"
         for group in param_groups:
             assert isinstance(group, dict) and len(group) > 1 and 'params' in group, \
@@ -108,6 +108,12 @@ def __init__(self, param_groups=[], lr=0.001, alpha=0.9, beta=0.999, lambda_coef
                             'do_bias_correction' : do_bias_correction,
                             'weight_decay_mode' : weight_decay_mode}
         super().__init__(name='AdamOptimizer', hyper_parameters=hyper_parameters, param_groups=param_groups)
+        self.alpha = alpha
+        self.beta = beta
+        self.lambda_coef = lambda_coef
+        self.epsilon = epsilon
+        self.do_bias_correction = do_bias_correction
+        self.weight_decay_mode = weight_decay_mode
 
 
 class Lamb(_OptimizerConfig):
@@ -148,3 +154,10 @@ def __init__(self, param_groups=[], lr=0.001, alpha=0.9, beta=0.999, lambda_coef
                             'epsilon' : epsilon,
                             'do_bias_correction' : do_bias_correction}
         super().__init__(name='LambOptimizer', hyper_parameters=hyper_parameters, param_groups=param_groups)
+        self.alpha = alpha
+        self.beta = beta
+        self.lambda_coef = lambda_coef
+        self.ratio_min = ratio_min
+        self.ratio_max = ratio_max
+        self.epsilon = epsilon
+        self.do_bias_correction = do_bias_correction
diff --git a/orttraining/orttraining/test/python/orttraining_test_pytorch_trainer_frontend.py b/orttraining/orttraining/test/python/orttraining_test_pytorch_trainer_frontend.py
index 0e5c0b34774d5..d3a249eceef63 100644
--- a/orttraining/orttraining/test/python/orttraining_test_pytorch_trainer_frontend.py
+++ b/orttraining/orttraining/test/python/orttraining_test_pytorch_trainer_frontend.py
@@ -12,7 +12,7 @@
 @pytest.mark.parametrize("test_input", [
     ({}),
     ({'batch': {},
-      'cuda': {},
+      'device': {},
       'distributed': {},
       'mixed_precision': {},
       'utils': {},
@@ -25,8 +25,8 @@ def testDefaultValues(test_input):
         'batch': {
             'gradient_accumulation_steps': 0
         },
-        'cuda': {
-            'device': None,
+        'device': {
+            'id': None,
             'mem_limit': 0
         },
         'distributed': {
@@ -64,6 +64,7 @@ def testInvalidMixedPrecisionEnabledSchema():
         {'mixed_precision': {'enabled': 1}})
     assert actual_values.mixed_precision[0].enabled[0] == expected_msg
 
+
 def testTrainStepInfo():
     '''Test valid initializations of TrainStepInfo'''
 
@@ -100,9 +101,74 @@ def testTrainStepInfoInvalidAllFinite(test_input):
     ('SGDOptimizer')
 ])
 def testOptimizerConfigs(optim_name):
-    '''Test initialization of _OptimizerConfig and its extensions'''
-    hyper_parameters={'lr':0.001}
-    cfg = optim.config._OptimizerConfig(name=optim_name, hyper_parameters=hyper_parameters, param_groups=[])
+    '''Test initialization of _OptimizerConfig'''
+    hyper_parameters = {'lr': 0.001, 'alpha': 0.9}
+    param_groups = [{'params': ['fc1.weight', 'fc2.weight'], 'alpha': 0.0}]
+    cfg = optim.config._OptimizerConfig(
+        name=optim_name, hyper_parameters=hyper_parameters, param_groups=param_groups)
+    assert cfg.name == optim_name
 
     rtol = 1e-03
-    assert_allclose(hyper_parameters['lr'], cfg.lr, rtol=rtol, err_msg="loss mismatch")
+    assert_allclose(hyper_parameters['lr'],
+                    cfg.lr, rtol=rtol, err_msg="lr mismatch")
+
+
+@pytest.mark.parametrize("optim_name,hyper_parameters,param_groups", [
+    ('AdamOptimizer', {'lr': -1}, []),  # invalid lr
+    ('FooOptimizer', {'lr': 0.001}, []),  # invalid name
+    ('SGDOptimizer', [], []),  # invalid type(hyper_parameters)
+    (optim.config.Adam, {'lr': 0.003}, []),  # invalid type(name)
+    ('AdamOptimizer', {'lr': None}, []),  # invalid 'lr' value (not a float)
+    ('SGDOptimizer', {'lr': 0.004}, {}),  # invalid type(param_groups)
+    ('AdamOptimizer', {'lr': 0.005, 'alpha': 2}, [[]]),  # invalid type(param_groups[i])
+    ('AdamOptimizer', {'lr': 0.005, 'alpha': 2}, [{'alpha': 1}]),  # missing 'params' at 'param_groups'
+    ('AdamOptimizer', {'lr': 0.005}, [{'params': 'param1', 'alpha': 1}]),  # missing 'alpha' at 'hyper_parameters'
+])
+def testOptimizerConfigsInvalidInputs(optim_name, hyper_parameters, param_groups):
+    '''Test invalid initialization of _OptimizerConfig'''
+
+    with pytest.raises(AssertionError):
+        optim.config._OptimizerConfig(
+            name=optim_name, hyper_parameters=hyper_parameters, param_groups=param_groups)
+
+
+def testSGD():
+    '''Test initialization of SGD'''
+    cfg = optim.config.SGD()
+    assert cfg.name == 'SGDOptimizer'
+
+    rtol = 1e-05
+    assert_allclose(0.001, cfg.lr, rtol=rtol, err_msg="lr mismatch")
+
+    cfg = optim.config.SGD(lr=0.002)
+    assert_allclose(0.002, cfg.lr, rtol=rtol, err_msg="lr mismatch")
+
+
+def testAdam():
+    '''Test initialization of Adam'''
+    cfg = optim.config.Adam()
+    assert cfg.name == 'AdamOptimizer'
+
+    rtol = 1e-05
+    assert_allclose(0.001, cfg.lr, rtol=rtol, err_msg="lr mismatch")
+    assert_allclose(0.9, cfg.alpha, rtol=rtol, err_msg="alpha mismatch")
+    assert_allclose(0.999, cfg.beta, rtol=rtol, err_msg="beta mismatch")
+    assert_allclose(0.0, cfg.lambda_coef, rtol=rtol, err_msg="lambda_coef mismatch")
+    assert_allclose(1e-8, cfg.epsilon, rtol=rtol, err_msg="epsilon mismatch")
+    assert cfg.do_bias_correction == True, "do_bias_correction mismatch"
+    assert cfg.weight_decay_mode == True, "weight_decay_mode mismatch"
+
+
+def testLamb():
+    '''Test initialization of Lamb'''
+    cfg = optim.config.Lamb()
+    assert cfg.name == 'LambOptimizer'
+    rtol = 1e-05
+    assert_allclose(0.001, cfg.lr, rtol=rtol, err_msg="lr mismatch")
+    assert_allclose(0.9, cfg.alpha, rtol=rtol, err_msg="alpha mismatch")
+    assert_allclose(0.999, cfg.beta, rtol=rtol, err_msg="beta mismatch")
+    assert_allclose(0.0, cfg.lambda_coef, rtol=rtol, err_msg="lambda_coef mismatch")
+    assert cfg.ratio_min == float('-inf'), "ratio_min mismatch"
+    assert cfg.ratio_max == float('inf'), "ratio_max mismatch"
+    assert_allclose(1e-6, cfg.epsilon, rtol=rtol, err_msg="epsilon mismatch")
+    assert cfg.do_bias_correction == True, "do_bias_correction mismatch"
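
A minimal usage sketch of what this change exposes, assuming `optim` refers to the same training `optim` module the tests above already import (that import is outside this diff); the values shown mirror only what the new testAdam/testLamb asserts check.

    # Illustration only, not part of the patch: hyper parameters passed to
    # Adam/Lamb are now also stored as attributes and can be read back.
    adam = optim.config.Adam(lr=0.002, alpha=0.85)
    assert adam.name == 'AdamOptimizer'
    assert adam.alpha == 0.85          # attribute exposed by this change
    lamb = optim.config.Lamb()
    assert lamb.ratio_max == float('inf')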