From 912062b8293f89e0af9ff8b90faf312089ad905b Mon Sep 17 00:00:00 2001
From: Thiago Crepaldi
Date: Fri, 17 Jul 2020 12:41:31 -0700
Subject: [PATCH] Add LRScheduler implementation (#4357)

---
 .../python/training/optim/__init__.py        |   2 +
 .../python/training/optim/lr_scheduler.py    | 260 +++++++++++++++++-
 .../orttraining_test_orttrainer_frontend.py  |  60 +++-
 3 files changed, 315 insertions(+), 7 deletions(-)

diff --git a/orttraining/orttraining/python/training/optim/__init__.py b/orttraining/orttraining/python/training/optim/__init__.py
index 7c8339c308008..d0541a609b25d 100644
--- a/orttraining/orttraining/python/training/optim/__init__.py
+++ b/orttraining/orttraining/python/training/optim/__init__.py
@@ -1 +1,3 @@
 from .config import _OptimizerConfig, AdamConfig, LambConfig, SGDConfig
+from .lr_scheduler import _LRScheduler, ConstantWarmupLRScheduler, CosineWarmupLRScheduler,\
+    LinearWarmupLRScheduler, PolyWarmupLRScheduler
diff --git a/orttraining/orttraining/python/training/optim/lr_scheduler.py b/orttraining/orttraining/python/training/optim/lr_scheduler.py
index 9418016d456f1..03cbdf6074331 100644
--- a/orttraining/orttraining/python/training/optim/lr_scheduler.py
+++ b/orttraining/orttraining/python/training/optim/lr_scheduler.py
@@ -1,2 +1,260 @@
+import math
+
+
 class _LRScheduler(object):
-    pass
\ No newline at end of file
+    r"""Base class for implementing custom learning rate schedulers
+
+    Schedulers can be either stateful or stateless.
+    Stateless implementations can only rely on information available at
+    :py:class:`.TrainStepInfo`.
+    Stateful implementations, on the other hand, can store additional parameters
+    by overriding the constructor.
+
+    In both cases, once the scheduler is configured, no user code is needed
+    to update the learning rate during each train step.
+
+    NOTE: The current implementation doesn't support 'lr' within :py:attr:`param_groups` entries.
+    """
+
+    def __init__(self):
+        self._last_lr = []
+
+    def get_lr(self, train_step_info):
+        r"""Returns a list of learning rates
+
+        Args:
+            train_step_info (:py:class:`.TrainStepInfo`): runtime info for the current training step
+
+        Returns:
+            ordered :py:obj:`list` of learning rates.
+                The first entry is the default learning rate and
+                the remaining entries refer to each parameter group.
+            NOTE: Currently, only the default learning rate is supported and a single-valued list must be returned.
+        """
+        raise NotImplementedError
+
+    def get_last_lr(self):
+        r"""Return the last learning rate computed by the LR Scheduler"""
+        return self._last_lr
+
+    def _step(self, train_step_info):
+        r"""Private method called to update the learning rate
+
+        NOTE: This method should never be called by the user.
+        """
+
+        # Store last lr for future inquiry
+        new_lr = self.get_lr(train_step_info)
+        self._last_lr = new_lr
+
+        # Update ORTTrainer's optimizer config instance
+        train_step_info.optimizer_config.lr = new_lr[0]
+
+
+class ConstantWarmupLRScheduler(_LRScheduler):
+    r"""Constant warmup strategy for learning rate update
+
+    Learning rate update strategy:
+        lr = base_lr * (step / total_steps) / warmup, when step / total_steps < warmup
+        lr = base_lr, when step / total_steps >= warmup
+
+    Args:
+        total_steps (int): total training steps for learning.
+        warmup (float, default is 0.002): portion of total steps for warmup. Range is [0, 1)
+
+    Example:
+        .. code-block:: python
+
+            # Initialize lr scheduler
+            lr_scheduler = ConstantWarmupLRScheduler(total_steps=512, warmup=0.002)
+
+            # Initialize ORTTrainer with lr scheduler
+            opts = ORTTrainerOptions({
+                'lr_scheduler': lr_scheduler
+            })
+            ort_trainer = ORTTrainer(..., options=opts)
+
+            # Call step() in every batch update
+            for inputs in batch_inputs:
+                outputs = ort_trainer.train_step(**inputs)
+    """
+
+    def __init__(self, total_steps, warmup=0.002):
+        super().__init__()
+        assert isinstance(total_steps, int) and total_steps > 0,\
+            "total_steps must be a strictly positive number"
+        assert isinstance(warmup, float) and warmup >= 0 and warmup < 1,\
+            "warmup must be a float in the range [0, 1)"
+        assert total_steps > warmup,\
+            "total_steps must be greater than warmup"
+
+        self.total_steps = total_steps
+        self.warmup = warmup
+
+    def _warmup_constant(self, train_step_info):
+        # Adds 1 to train_step_info.step and self.total_steps to prevent zeroing lr
+        x = (train_step_info.step + 1) / (self.total_steps + 1)
+        if x < self.warmup:
+            return x / self.warmup
+        return 1.0
+
+    def get_lr(self, train_step_info):
+        warmup = self._warmup_constant(train_step_info)
+        return [train_step_info.optimizer_config.lr * warmup]
+
+
+class CosineWarmupLRScheduler(_LRScheduler):
+    r"""Cosine warmup strategy for learning rate update
+
+    Learning rate update strategy:
+        lr = base_lr * (step / total_steps) / warmup, when step / total_steps < warmup
+        lr = base_lr * 0.5 * (1.0 + cosine(pi * (step / total_steps))), when step / total_steps >= warmup
+
+    Args:
+        total_steps (int): total training steps for learning.
+        warmup (float, default is 0.002): portion of total steps for warmup. Range is [0, 1)
+
+    Example:
+        .. code-block:: python
+
+            # Initialize lr scheduler
+            lr_scheduler = CosineWarmupLRScheduler(total_steps=512, warmup=0.002)
+
+            # Initialize ORTTrainer with lr scheduler
+            opts = ORTTrainerOptions({
+                'lr_scheduler': lr_scheduler
+            })
+            ort_trainer = ORTTrainer(..., options=opts)
+
+            # Call step() in every batch update
+            for inputs in batch_inputs:
+                outputs = ort_trainer.train_step(**inputs)
+    """
+
+    def __init__(self, total_steps, warmup=0.002):
+        super().__init__()
+        assert isinstance(total_steps, int) and total_steps > 0,\
+            "total_steps must be a strictly positive number"
+        assert isinstance(warmup, float) and warmup >= 0 and warmup < 1,\
+            "warmup must be a float in the range [0, 1)"
+        assert total_steps > warmup,\
+            "total_steps must be greater than warmup"
+
+        self.total_steps = total_steps
+        self.warmup = warmup
+
+    def _warmup_cosine(self, train_step_info):
+        # Adds 1 to train_step_info.step and self.total_steps to prevent zeroing lr
+        x = (train_step_info.step + 1) / (self.total_steps + 1)
+        if x < self.warmup:
+            return x / self.warmup
+        return 0.5 * (1.0 + math.cos(math.pi * x))
+
+    def get_lr(self, train_step_info):
+        return [train_step_info.optimizer_config.lr * self._warmup_cosine(train_step_info)]
+
+
+class LinearWarmupLRScheduler(_LRScheduler):
+    r"""Linear warmup strategy for learning rate update
+
+    Learning rate update strategy:
+        lr = base_lr * (step / total_steps) / warmup, when step / total_steps < warmup
+        lr = base_lr * max(((step / total_steps) - 1.) / (warmup - 1.), 0.), when step / total_steps >= warmup
+
+    Args:
+        total_steps (int): total training steps for learning.
+        warmup (float, default is 0.002): portion of total steps for warmup. Range is [0, 1)
+
+    Example:
+        .. code-block:: python
+
+            # Initialize lr scheduler
+            lr_scheduler = LinearWarmupLRScheduler(total_steps=512, warmup=0.002)
+
+            # Initialize ORTTrainer with lr scheduler
+            opts = ORTTrainerOptions({
+                'lr_scheduler': lr_scheduler
+            })
+            ort_trainer = ORTTrainer(..., options=opts)
+
+            # Call step() in every batch update
+            for inputs in batch_inputs:
+                outputs = ort_trainer.train_step(**inputs)
+    """
+
+    def __init__(self, total_steps, warmup=0.002):
+        super().__init__()
+        assert isinstance(total_steps, int) and total_steps > 0,\
+            "total_steps must be a strictly positive number"
+        assert isinstance(warmup, float) and warmup >= 0 and warmup < 1,\
+            "warmup must be a float in the range [0, 1)"
+        assert total_steps > warmup,\
+            "total_steps must be greater than warmup"
+
+        self.total_steps = total_steps
+        self.warmup = warmup
+
+    def _warmup_linear(self, train_step_info):
+        # Adds 1 to train_step_info.step and self.total_steps to prevent zeroing lr
+        x = (train_step_info.step + 1) / (self.total_steps + 1)
+        if x < self.warmup:
+            return x / self.warmup
+        return max((x - 1.) / (self.warmup - 1.), 0.)
+
+    def get_lr(self, train_step_info):
+        return [train_step_info.optimizer_config.lr * self._warmup_linear(train_step_info)]
+
+
+class PolyWarmupLRScheduler(_LRScheduler):
+    r"""Polynomial warmup strategy for learning rate update
+
+    Learning rate update strategy:
+        lr = base_lr * (step / total_steps) / warmup, when step / total_steps < warmup
+        lr = base_lr * (1 - step / total_steps) ^ degree, when step / total_steps >= warmup
+
+    Args:
+        total_steps (int): total training steps for learning.
+        warmup (float, default is 0.002): portion of total steps for warmup. Range is [0, 1)
+        degree (float, default is 0.5): polynomial power
+
+    Example:
+        .. code-block:: python
+
+            # Initialize lr scheduler
+            lr_scheduler = PolyWarmupLRScheduler(total_steps=512, warmup=0.002, degree=0.5)
+
+            # Initialize ORTTrainer with lr scheduler
+            opts = ORTTrainerOptions({
+                'lr_scheduler': lr_scheduler
+            })
+            ort_trainer = ORTTrainer(..., options=opts)
+
+            # Call step() in every batch update
+            for inputs in batch_inputs:
+                outputs = ort_trainer.train_step(**inputs)
+    """
+
+    def __init__(self, total_steps, warmup=0.002, degree=0.5):
+        super().__init__()
+        assert isinstance(total_steps, int) and total_steps > 0,\
+            "total_steps must be a strictly positive number"
+        assert isinstance(warmup, float) and warmup >= 0 and warmup < 1,\
+            "warmup must be a float in the range [0, 1)"
+        assert total_steps > warmup,\
+            "total_steps must be greater than warmup"
+        assert isinstance(degree, float) and degree >= 0,\
+            "degree must be a positive float"
+
+        self.total_steps = total_steps
+        self.warmup = warmup
+        self.degree = degree
+
+    def _warmup_poly(self, train_step_info):
+        # Adds 1 to train_step_info.step and self.total_steps to prevent zeroing lr
+        x = (train_step_info.step + 1) / (self.total_steps + 1)
+        if x < self.warmup:
+            return x / self.warmup
+        return (1.0 - x)**self.degree
+
+    def get_lr(self, train_step_info):
+        return [train_step_info.optimizer_config.lr * self._warmup_poly(train_step_info)]
diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py
index c684e7f4736ff..31d47db5f49e7 100644
--- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py
+++ b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py
@@ -4,7 +4,7 @@ from onnxruntime.capi.training import orttrainer_options as orttrainer_options
 from onnxruntime.capi.training import model_desc_validation as md_val
-from onnxruntime.capi.training import orttrainer, amp, optim
+from onnxruntime.capi.training import orttrainer, amp, optim, TrainStepInfo
 
 
 @pytest.mark.parametrize("test_input", [
@@ -80,19 +80,22 @@ def testORTTrainerModelDescValidSchemas(test_input):
 @pytest.mark.parametrize("test_input,error_msg", [
     ({'inputs': [(True, [])],
       'outputs': [(True, [])]},
-     "Invalid model_desc: {'inputs': [{0: ['the first element of the tuple (aka name) must be a string']}], 'outputs': [{0: ['the first element of the tuple (aka name) must be a string']}]}"),
+     "Invalid model_desc: {'inputs': [{0: ['the first element of the tuple (aka name) must be a string']}], "
+     "'outputs': [{0: ['the first element of the tuple (aka name) must be a string']}]}"),
     ({'inputs': [('in1', None)],
       'outputs': [('out1', None)]},
-     "Invalid model_desc: {'inputs': [{0: ['the second element of the tuple (aka shape) must be a list']}], 'outputs': [{0: ['the second element of the tuple (aka shape) must be a list']}]}"),
+     "Invalid model_desc: {'inputs': [{0: ['the second element of the tuple (aka shape) must be a list']}], "
+     "'outputs': [{0: ['the second element of the tuple (aka shape) must be a list']}]}"),
     ({'inputs': [('in1', [])],
-     'outputs': [('out1', [], None)]},
+      'outputs': [('out1', [], None)]},
      "Invalid model_desc: {'outputs': [{0: ['the third element of the tuple (aka is_loss) must be a boolean']}]}"),
     ({'inputs': [('in1', [True])],
       'outputs': [('out1', [True])]},
-     "Invalid model_desc: {'inputs': [{0: ['each shape must be either a string or integer']}], 'outputs': [{0: ['each shape must be either a string or integer']}]}"),
+     "Invalid model_desc: {'inputs': [{0: ['each shape must be either a string or integer']}], "
+     "'outputs': [{0: ['each shape must be either a string or integer']}]}"),
     ({'inputs': [('in1', [])],
       'outputs': [('out1', [], True), ('out2', [], True)]},
-      "Invalid model_desc: {'outputs': [{1: ['only one is_loss can bet set to True']}]}"),
+     "Invalid model_desc: {'outputs': [{1: ['only one is_loss can bet set to True']}]}"),
 ])
 def testORTTrainerModelDescInvalidSchemas(test_input, error_msg):
     r''' Test different ways of using default values for incomplete input'''
@@ -356,3 +359,48 @@ def testInvalidParamparams(optim_name):
     else:
         raise ValueError('invalid input')
     assert str(e.value) == "'lr' is not supported inside params"
+
+
+def testLinearLRSchedulerCreation():
+    total_steps = 10
+    warmup = 0.05
+
+    lr_scheduler = optim.lr_scheduler.LinearWarmupLRScheduler(total_steps,
+                                                              warmup)
+
+    # Initial state
+    assert lr_scheduler.total_steps == total_steps
+    assert lr_scheduler.warmup == warmup
+
+
+@pytest.mark.parametrize("lr_scheduler,expected_values", [
+    (optim.lr_scheduler.ConstantWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
+                                                    0.023843, 0.023843, 0.023843, 0.023843, 0.023843]),
+    (optim.lr_scheduler.CosineWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
+                                                  0.010225, 0.002989, 0.0005158, 0.000040937, 0.0000008291]),
+    (optim.lr_scheduler.LinearWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
+                                                  0.021675, 0.0157636, 0.0085983, 0.0031266, 0.00056847]),
+    (optim.lr_scheduler.PolyWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
+                                                0.0160749, 0.0096935, 0.0050622, 0.0021585, 0.000650833])
+])
+def testLRSchedulerUpdateImpl(lr_scheduler, expected_values):
+    rtol = 1e-04
+
+    # Initial state
+    initial_lr = 1
+    total_steps = 10
+    warmup = 0.5
+    optimizer_config = optim.SGDConfig(lr=initial_lr)
+    lr_scheduler = lr_scheduler(total_steps,
+                                warmup)
+
+    # First half is warmup
+    for step in range(total_steps):
+        # Emulate ORTTrainer.train_step() call that updates its train_step_info
+        train_step_info = TrainStepInfo(step=step, optimizer_config=optimizer_config)
+
+        lr_scheduler._step(train_step_info)
+        lr_list = lr_scheduler.get_last_lr()
+        assert len(lr_list) == 1
+        assert_allclose(lr_list[0],
+                        expected_values[step], rtol=rtol, err_msg="lr mismatch")
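
For reference, the expected_values above compound from one step to the next: get_lr() scales train_step_info.optimizer_config.lr, and _step() writes the result back into that same config, so each step's warmup factor multiplies the previous step's learning rate rather than a fixed base_lr. Below is a minimal standalone sketch of that arithmetic for the LinearWarmupLRScheduler row; the constants and the warmup_linear helper are illustrative stand-ins mirroring the test, not part of the patch:

    # Stand-ins mirroring testLRSchedulerUpdateImpl (illustrative only)
    initial_lr = 1.0
    total_steps = 10
    warmup = 0.5

    def warmup_linear(step, total_steps, warmup):
        # Mirrors LinearWarmupLRScheduler._warmup_linear: the +1 offsets avoid a zero lr at step 0
        x = (step + 1) / (total_steps + 1)
        if x < warmup:
            return x / warmup
        return max((x - 1.) / (warmup - 1.), 0.)

    lr = initial_lr
    for step in range(total_steps):
        # _step() stores the scaled value back into optimizer_config.lr,
        # so successive factors compound instead of scaling a fixed base lr
        lr = lr * warmup_linear(step, total_steps, warmup)
        print(step, round(lr, 6))
    # First values printed: 0.181818, 0.066116, 0.036063, ... matching the LinearWarmupLRScheduler row above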