Fix bugs and add tests
Thiago Crepaldi committed Jun 18, 2020
1 parent 1104116 commit 3fb07ab
Showing 2 changed files with 87 additions and 8 deletions.
15 changes: 14 additions & 1 deletion orttraining/orttraining/python/training/optim/config.py
@@ -35,7 +35,7 @@ def __init__(self, name, hyper_parameters, param_groups=[]):
             "'name' must be one of 'AdamOptimizer', 'LambOptimizer' or 'SGDOptimizer'"
         assert isinstance(hyper_parameters, dict), "'hyper_parameters' must be a dict"
         assert 'lr' in hyper_parameters, "'hyper_parameters' must contain a {'lr' : positive number} entry"
-        assert hyper_parameters['lr'] >= 0, "lr must be a positive number"
+        assert isinstance(hyper_parameters['lr'], float) and hyper_parameters['lr'] >= 0, "lr must be a positive number"
         assert isinstance(param_groups, list), "'param_groups' must be a list"
         for group in param_groups:
             assert isinstance(group, dict) and len(group) > 1 and 'params' in group, \
@@ -108,6 +108,12 @@ def __init__(self, param_groups=[], lr=0.001, alpha=0.9, beta=0.999, lambda_coef
                             'do_bias_correction' : do_bias_correction,
                             'weight_decay_mode' : weight_decay_mode}
         super().__init__(name='AdamOptimizer', hyper_parameters=hyper_parameters, param_groups=param_groups)
+        self.alpha = alpha
+        self.beta = beta
+        self.lambda_coef = lambda_coef
+        self.epsilon = epsilon
+        self.do_bias_correction = do_bias_correction
+        self.weight_decay_mode = weight_decay_mode
 
 
 class Lamb(_OptimizerConfig):
@@ -148,3 +154,10 @@ def __init__(self, param_groups=[], lr=0.001, alpha=0.9, beta=0.999, lambda_coef
                             'epsilon' : epsilon,
                             'do_bias_correction' : do_bias_correction}
         super().__init__(name='LambOptimizer', hyper_parameters=hyper_parameters, param_groups=param_groups)
+        self.alpha = alpha
+        self.beta = beta
+        self.lambda_coef = lambda_coef
+        self.ratio_min = ratio_min
+        self.ratio_max = ratio_max
+        self.epsilon = epsilon
+        self.do_bias_correction = do_bias_correction
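
A minimal usage sketch (not part of this commit) of what the config.py change adds: the hyper parameters passed to Adam and Lamb are now also exposed as instance attributes, and the tightened assertion rejects any 'lr' that is not a float greater than or equal to zero. The import path below is an assumption and may differ from the one used by the tests in this commit.

# Sketch only; the `optim` import path is assumed, not taken from this diff.
from onnxruntime.training import optim  # assumption: adjust to your build's package layout

adam_cfg = optim.config.Adam()
print(adam_cfg.alpha, adam_cfg.beta, adam_cfg.epsilon)   # attributes exposed by this change (0.9, 0.999, 1e-8 per the new tests)

lamb_cfg = optim.config.Lamb()
print(lamb_cfg.ratio_min, lamb_cfg.ratio_max)            # -inf, inf, now readable as attributes

# The stricter check raises AssertionError for an int or negative 'lr'.
try:
    optim.config._OptimizerConfig(name='AdamOptimizer', hyper_parameters={'lr': 1}, param_groups=[])
except AssertionError as err:
    print('rejected:', err)                              # "lr must be a positive number"
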
80 changes: 73 additions & 7 deletions (second changed file; file path not shown in this capture)
@@ -12,7 +12,7 @@
 @pytest.mark.parametrize("test_input", [
     ({}),
     ({'batch': {},
-      'cuda': {},
+      'device': {},
       'distributed': {},
       'mixed_precision': {},
       'utils': {},
@@ -25,8 +25,8 @@ def testDefaultValues(test_input):
         'batch': {
             'gradient_accumulation_steps': 0
         },
-        'cuda': {
-            'device': None,
+        'device': {
+            'id': None,
             'mem_limit': 0
         },
         'distributed': {
@@ -64,6 +64,7 @@ def testInvalidMixedPrecisionEnabledSchema():
         {'mixed_precision': {'enabled': 1}})
     assert actual_values.mixed_precision[0].enabled[0] == expected_msg
 
+
 def testTrainStepInfo():
     '''Test valid initializations of TrainStepInfo'''
 
@@ -100,9 +101,74 @@ def testTrainStepInfoInvalidAllFinite(test_input):
     ('SGDOptimizer')
 ])
 def testOptimizerConfigs(optim_name):
-    '''Test initialization of _OptimizerConfig and its extensions'''
-    hyper_parameters={'lr':0.001}
-    cfg = optim.config._OptimizerConfig(name=optim_name, hyper_parameters=hyper_parameters, param_groups=[])
+    '''Test initialization of _OptimizerConfig'''
+    hyper_parameters = {'lr': 0.001, 'alpha': 0.9}
+    param_groups = [{'params': ['fc1.weight', 'fc2.weight'], 'alpha':.0}]
+    cfg = optim.config._OptimizerConfig(
+        name=optim_name, hyper_parameters=hyper_parameters, param_groups=param_groups)
+
     assert cfg.name == optim_name
     rtol = 1e-03
-    assert_allclose(hyper_parameters['lr'], cfg.lr, rtol=rtol, err_msg="loss mismatch")
+    assert_allclose(hyper_parameters['lr'],
+                    cfg.lr, rtol=rtol, err_msg="lr mismatch")
+
+
@pytest.mark.parametrize("optim_name,hyper_parameters,param_groups", [
('AdamOptimizer', {'lr': -1}, []), # invalid lr
('FooOptimizer', {'lr': 0.001}, []), # invalid name
('SGDOptimizer', [], []), # invalid type(hyper_parameters)
(optim.config.Adam, {'lr': 0.003}, []), # invalid type(name)
('AdamOptimizer', {'lr': None}, []), # missing 'lr' hyper parameter
('SGDOptimizer', {'lr': 0.004}, {}), # invalid type(param_groups)
('AdamOptimizer', {'lr': 0.005, 'alpha': 2}, [[]]), # invalid type(param_groups[i])
('AdamOptimizer', {'lr': 0.005, 'alpha': 2}, [{'alpha': 1}]), # missing 'params' at 'param_groups'
('AdamOptimizer', {'lr': 0.005}, [{'params': 'param1', 'alpha': 1}]), # missing 'alpha' at 'hyper_parameters'
])
def testOptimizerConfigsInvalidInputs(optim_name, hyper_parameters, param_groups):
'''Test invalid initialization of _OptimizerConfig'''

with pytest.raises(AssertionError):
optim.config._OptimizerConfig(
name=optim_name, hyper_parameters=hyper_parameters, param_groups=param_groups)


def testSGD():
'''Test initialization of SGD'''
cfg = optim.config.SGD()
assert cfg.name == 'SGDOptimizer'

rtol = 1e-05
assert_allclose(0.001, cfg.lr, rtol=rtol, err_msg="lr mismatch")

cfg = optim.config.SGD(lr=0.002)
assert_allclose(0.002, cfg.lr, rtol=rtol, err_msg="lr mismatch")


def testAdam():
'''Test initialization of Adam'''
cfg = optim.config.Adam()
assert cfg.name == 'AdamOptimizer'

rtol = 1e-05
assert_allclose(0.001, cfg.lr, rtol=rtol, err_msg="lr mismatch")
assert_allclose(0.9, cfg.alpha, rtol=rtol, err_msg="alpha mismatch")
assert_allclose(0.999, cfg.beta, rtol=rtol, err_msg="beta mismatch")
assert_allclose(0.0, cfg.lambda_coef, rtol=rtol, err_msg="lambda_coef mismatch")
assert_allclose(1e-8, cfg.epsilon, rtol=rtol, err_msg="epsilon mismatch")
+    assert cfg.do_bias_correction == True, "do_bias_correction mismatch"
+    assert cfg.weight_decay_mode == True, "weight_decay_mode mismatch"
+
+
+def testLamb():
+    '''Test initialization of Lamb'''
+    cfg = optim.config.Lamb()
+    assert cfg.name == 'LambOptimizer'
+    rtol = 1e-05
+    assert_allclose(0.001, cfg.lr, rtol=rtol, err_msg="lr mismatch")
+    assert_allclose(0.9, cfg.alpha, rtol=rtol, err_msg="alpha mismatch")
+    assert_allclose(0.999, cfg.beta, rtol=rtol, err_msg="beta mismatch")
+    assert_allclose(0.0, cfg.lambda_coef, rtol=rtol, err_msg="lambda_coef mismatch")
+    assert cfg.ratio_min == float('-inf'), "ratio_min mismatch"
+    assert cfg.ratio_max == float('inf'), "ratio_max mismatch"
+    assert_allclose(1e-6, cfg.epsilon, rtol=rtol, err_msg="epsilon mismatch")
+    assert cfg.do_bias_correction == True, "do_bias_correction mismatch"
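
For reference, a hedged sketch of how the new tests could be invoked through pytest's Python API. The test module's path is not shown in this capture, so the directory below is an assumption; substitute the actual location of the second changed file.

# Hypothetical invocation; the test directory path is an assumption.
import pytest
pytest.main(['-v', '-k', 'Optimizer or SGD or Adam or Lamb',
             'orttraining/orttraining/test/python/'])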
