Skip to content

Commit

Permalink
Add LRScheduler implementation (#4357)
Browse files Browse the repository at this point in the history
  • Loading branch information
Thiago Crepaldi committed Aug 15, 2020
1 parent da8b45b commit afcdc57
Show file tree
Hide file tree
Showing 3 changed files with 315 additions and 7 deletions.
2 changes: 2 additions & 0 deletions orttraining/orttraining/python/training/optim/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .config import _OptimizerConfig, AdamConfig, LambConfig, SGDConfig
from .lr_scheduler import _LRScheduler, ConstantWarmupLRScheduler, CosineWarmupLRScheduler,\
LinearWarmupLRScheduler, PolyWarmupLRScheduler
260 changes: 259 additions & 1 deletion orttraining/orttraining/python/training/optim/lr_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,260 @@
import math


class _LRScheduler(object):
pass
r"""Base class for implementing custom learning rate schedulers
Schedulers can be either stateful or stateless.
Stateless implementation can only rely on information available at
:py:class:`.TrainStepInfo`.
Stateful implementation, on the other hand, can store additional parameters
by overriding the constructor.
In both cases, once the scheduler is configured, no user code is needed
to update learning rate during each train step.
NOTE: Current implementation doesn't support 'lr' within :py:attr:`param_groups` entries.
"""

def __init__(self):
self._last_lr = []

def get_lr(self, train_step_info):
r"""Returns a list of learning rate
Args:
train_step_info (:py:class:`.TrainStepInfo`): runtime info for current training step
Returns:
ordered :py:obj:`list` of learning rates.
The first entry is the default learning rate and
the remaining refer to each parameter group.
NOTE: Currently, only default learning rate is supported and a single-valued list must be returned.
"""
raise NotImplementedError

def get_last_lr(self):
r""" Return last computed learning rate by LR Scheduler"""
return self._last_lr

def _step(self, train_step_info):
r"""Private method called to update learning rate
NOTE: This class should never be called by the user.
"""

# Store last lr for future inquiry
new_lr = self.get_lr(train_step_info)
self._last_lr = new_lr

# Update ORTTrainer's optimizer config instance
train_step_info.optimizer_config.lr = new_lr[0]


class ConstantWarmupLRScheduler(_LRScheduler):
r"""Constant warmup strategy for learning rate update
Learning rate update strategy:
lr = base_lr * (step / total_steps) / warmup, when step / total_steps < warmup
lr = base_lr, when step / total_steps >= warmup
Args:
total_steps (int): total training steps for learning.
warmup (float, default is 0.002): portion of total steps for warmup. Range is (0, 1]
Example:
.. code-block:: python
# Initialize lr scheduler
lr_scheduler = ConstantWarmupLRScheduler(total_steps=512, warmup=0.002)
# Initialize ORTTrainer with lr scheduler
opts = ORTTrainerOptions({
lr_scheduler: lr_scheduler
})
ort_trainer = ORTTrainer(..., options=opts)
# Call step() in every batch update
for inputs in batch_inputs:
outputs = ort_trainer.train_step(**inputs)
"""

def __init__(self, total_steps, warmup=0.002):
super().__init__()
assert isinstance(total_steps, int) and total_steps > 0,\
"total_steps must be a strict positive number"
assert isinstance(warmup, float) and warmup >= 0 and warmup < 1,\
"warmup must be a float between (0, 1]"
assert total_steps > warmup,\
"total_steps must be greater than warmup"

self.total_steps = total_steps
self.warmup = warmup

def _warmup_constant(self, train_step_info):
# Adds 1 to train_step_info.step and self.total_steps to prevent zero'ing lr
x = (train_step_info.step + 1) / (self.total_steps + 1)
if x < self.warmup:
return x/self.warmup
return 1.0

def get_lr(self, train_step_info):
warmup = self._warmup_constant(train_step_info)
return [train_step_info.optimizer_config.lr * warmup]


class CosineWarmupLRScheduler(_LRScheduler):
r"""Cosine warmup strategy for learning rate update
Learning rate update strategy:
lr = base_lr * (step / total_steps) / warmup, when step / total_steps < warmup
lr = base_lr * 0.5 * (1.0 + cosine(pi * (step / total_steps))), when step / total_steps >= warmup
Args:
total_steps (int): total training steps for learning.
warmup (float, default is 0.002): portion of total steps for warmup. Range is (0, 1]
Example:
.. code-block:: python
# Initialize lr scheduler
lr_scheduler = CosineWarmupLRScheduler(total_steps=512, warmup=0.002)
# Initialize ORTTrainer with lr scheduler
opts = ORTTrainerOptions({
lr_scheduler: lr_scheduler
})
ort_trainer = ORTTrainer(..., options=opts)
# Call step() in every batch update
for inputs in batch_inputs:
outputs = ort_trainer.train_step(**inputs)
"""

def __init__(self, total_steps, warmup=0.002):
super().__init__()
assert isinstance(total_steps, int) and total_steps > 0,\
"total_steps must be a strict positive number"
assert isinstance(warmup, float) and warmup >= 0 and warmup < 1,\
"warmup must be a float between (0, 1]"
assert total_steps > warmup,\
"total_steps must be greater than warmup"

self.total_steps = total_steps
self.warmup = warmup

def _warmup_cosine(self, train_step_info):
# Adds 1 to train_step_info.step and self.total_steps to prevent zero'ing lr
x = (train_step_info.step + 1) / (self.total_steps + 1)
if x < self.warmup:
return x/self.warmup
return 0.5 * (1.0 + math.cos(math.pi * x))

def get_lr(self, train_step_info):
return [train_step_info.optimizer_config.lr * self._warmup_cosine(train_step_info)]


class LinearWarmupLRScheduler(_LRScheduler):
r"""Linear warmup strategy for learning rate update
Learning rate update strategy:
lr = base_lr * (step / total_steps) / warmup, when step / total_steps < warmup
lr = base_lr * max(((step / total_steps) - 1.) / (warmup - 1.), 0.), when step / total_steps >= warmup
Args:
total_steps (int): total training steps for learning.
warmup (float, default is 0.002): portion of total steps for warmup. Range is (0, 1]
Example:
.. code-block:: python
# Initialize lr scheduler
lr_scheduler = LinearWarmupLRScheduler(total_steps=512, warmup=0.002)
# Initialize ORTTrainer with lr scheduler
opts = ORTTrainerOptions({
lr_scheduler: lr_scheduler
})
ort_trainer = ORTTrainer(..., options=opts)
# Call step() in every batch update
for inputs in batch_inputs:
outputs = ort_trainer.train_step(**inputs)
"""

def __init__(self, total_steps, warmup=0.002):
super().__init__()
assert isinstance(total_steps, int) and total_steps > 0,\
"total_steps must be a strict positive number"
assert isinstance(warmup, float) and warmup >= 0 and warmup < 1,\
"warmup must be a float between (0, 1]"
assert total_steps > warmup,\
"total_steps must be greater than warmup"

self.total_steps = total_steps
self.warmup = warmup

def _warmup_linear(self, train_step_info):
# Adds 1 to train_step_info.step and self.total_steps to prevent zero'ing lr
x = (train_step_info.step + 1) / (self.total_steps + 1)
if x < self.warmup:
return x / self.warmup
return max((x - 1.) / (self.warmup - 1.), 0.)

def get_lr(self, train_step_info):
return [train_step_info.optimizer_config.lr * self._warmup_linear(train_step_info)]


class PolyWarmupLRScheduler(_LRScheduler):
r"""Polynomial warmup strategy for learning rate update
Learning rate update strategy:
lr = base_lr * (step / total_steps) / warmup, when step / total_steps < warmup
lr = base_lr * (1 − step / total_steps ) ^ degree, when step / total_steps >= warmup
Args:
total_steps (int): total training steps for learning.
warmup (float, default is 0.002): portion of total steps for warmup. Range is (0, 1]
degree (float, default is 0.5): polynomial power
Example:
.. code-block:: python
# Initialize lr scheduler
lr_scheduler = PolyWarmupLRScheduler(total_steps=512, warmup=0.002, degree=0.5)
# Initialize ORTTrainer with lr scheduler
opts = ORTTrainerOptions({
lr_scheduler: lr_scheduler
})
ort_trainer = ORTTrainer(..., options=opts)
# Call step() in every batch update
for inputs in batch_inputs:
outputs = ort_trainer.train_step(**inputs)
"""

def __init__(self, total_steps, warmup=0.002, degree=0.5):
super().__init__()
assert isinstance(total_steps, int) and total_steps > 0,\
"total_steps must be a strict positive number"
assert isinstance(warmup, float) and warmup >= 0 and warmup < 1,\
"warmup must be a float between (0, 1]"
assert total_steps > warmup,\
"total_steps must be greater than warmup"
assert isinstance(degree, float) and warmup >= 0,\
"degree must be a positive float"

self.total_steps = total_steps
self.warmup = warmup
self.degree = degree

def _warmup_poly(self, train_step_info):
# Adds 1 to train_step_info.step and self.total_steps to prevent zero'ing lr
x = (train_step_info.step + 1) / (self.total_steps + 1)
if x < self.warmup:
return x/self.warmup
return (1.0 - x)**self.degree

def get_lr(self, train_step_info):
return [train_step_info.optimizer_config.lr * self._warmup_poly(train_step_info)]
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from onnxruntime.capi.training import orttrainer_options as orttrainer_options
from onnxruntime.capi.training import model_desc_validation as md_val
from onnxruntime.capi.training import orttrainer, amp, optim
from onnxruntime.capi.training import orttrainer, amp, optim, TrainStepInfo


@pytest.mark.parametrize("test_input", [
Expand Down Expand Up @@ -80,19 +80,22 @@ def testORTTrainerModelDescValidSchemas(test_input):
@pytest.mark.parametrize("test_input,error_msg", [
({'inputs': [(True, [])],
'outputs': [(True, [])]},
"Invalid model_desc: {'inputs': [{0: ['the first element of the tuple (aka name) must be a string']}], 'outputs': [{0: ['the first element of the tuple (aka name) must be a string']}]}"),
"Invalid model_desc: {'inputs': [{0: ['the first element of the tuple (aka name) must be a string']}], "
"'outputs': [{0: ['the first element of the tuple (aka name) must be a string']}]}"),
({'inputs': [('in1', None)],
'outputs': [('out1', None)]},
"Invalid model_desc: {'inputs': [{0: ['the second element of the tuple (aka shape) must be a list']}], 'outputs': [{0: ['the second element of the tuple (aka shape) must be a list']}]}"),
"Invalid model_desc: {'inputs': [{0: ['the second element of the tuple (aka shape) must be a list']}], "
"'outputs': [{0: ['the second element of the tuple (aka shape) must be a list']}]}"),
({'inputs': [('in1', [])],
'outputs': [('out1', [], None)]},
'outputs': [('out1', [], None)]},
"Invalid model_desc: {'outputs': [{0: ['the third element of the tuple (aka is_loss) must be a boolean']}]}"),
({'inputs': [('in1', [True])],
'outputs': [('out1', [True])]},
"Invalid model_desc: {'inputs': [{0: ['each shape must be either a string or integer']}], 'outputs': [{0: ['each shape must be either a string or integer']}]}"),
"Invalid model_desc: {'inputs': [{0: ['each shape must be either a string or integer']}], "
"'outputs': [{0: ['each shape must be either a string or integer']}]}"),
({'inputs': [('in1', [])],
'outputs': [('out1', [], True), ('out2', [], True)]},
"Invalid model_desc: {'outputs': [{1: ['only one is_loss can bet set to True']}]}"),
"Invalid model_desc: {'outputs': [{1: ['only one is_loss can bet set to True']}]}"),
])
def testORTTrainerModelDescInvalidSchemas(test_input, error_msg):
r''' Test different ways of using default values for incomplete input'''
Expand Down Expand Up @@ -356,3 +359,48 @@ def testInvalidParamparams(optim_name):
else:
raise ValueError('invalid input')
assert str(e.value) == "'lr' is not supported inside params"


def testLinearLRSchedulerCreation():
total_steps = 10
warmup = 0.05

lr_scheduler = optim.lr_scheduler.LinearWarmupLRScheduler(total_steps,
warmup)

# Initial state
assert lr_scheduler.total_steps == total_steps
assert lr_scheduler.warmup == warmup


@pytest.mark.parametrize("lr_scheduler,expected_values", [
(optim.lr_scheduler.ConstantWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
0.023843, 0.023843, 0.023843, 0.023843, 0.023843]),
(optim.lr_scheduler.CosineWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
0.010225, 0.002989, 0.0005158, 0.000040937, 0.0000008291]),
(optim.lr_scheduler.LinearWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
0.021675, 0.0157636, 0.0085983, 0.0031266, 0.00056847]),
(optim.lr_scheduler.PolyWarmupLRScheduler, [0.181818, 0.066116, 0.036063, 0.026228, 0.023843,
0.0160749, 0.0096935, 0.0050622, 0.0021585, 0.000650833])
])
def testLRSchedulerUpdateImpl(lr_scheduler, expected_values):
rtol = 1e-04

# Initial state
initial_lr = 1
total_steps = 10
warmup = 0.5
optimizer_config = optim.SGDConfig(lr=initial_lr)
lr_scheduler = lr_scheduler(total_steps,
warmup)

# First half is warmup
for step in range(total_steps):
# Emulate ORTTRainer.train_step() call that updates its train_step_info
train_step_info = TrainStepInfo(step=step, optimizer_config=optimizer_config)

lr_scheduler._step(train_step_info)
lr_list = lr_scheduler.get_last_lr()
assert len(lr_list) == 1
assert_allclose(lr_list[0],
expected_values[step], rtol=rtol, err_msg="lr mismatch")

0 comments on commit afcdc57

Please sign in to comment.