diff --git a/ignite/handlers/param_scheduler.py b/ignite/handlers/param_scheduler.py
index 7a878ce4fd1..d0d0cba4fd8 100644
--- a/ignite/handlers/param_scheduler.py
+++ b/ignite/handlers/param_scheduler.py
@@ -390,6 +390,9 @@ class LinearCyclicalScheduler(CyclicalScheduler):
         save_history: whether to log the parameter values to
             `engine.state.param_history`, (default=False).
         param_group_index: optimizer's parameters group to use.
+        monotonic: whether to schedule only one half of the cycle: descending or ascending.
+            If False, this argument cannot be used together with ``warmup_duration``.
+            (default=False).
 
     Note:
         If the scheduler is bound to an 'ITERATION_*' event, 'cycle_size' should
@@ -465,12 +468,28 @@ def print_lr():
 
     .. versionchanged:: 0.4.13
         Added cyclic warm-up to the scheduler using ``warmup_duration``.
+
+    .. versionchanged:: 0.5.0
+        Added ``monotonic`` argument.
     """
 
+    def __init__(self, *args: Any, monotonic: bool = False, **kwargs: Any):
+        super(LinearCyclicalScheduler, self).__init__(*args, **kwargs)
+        self.monotonic = monotonic
+        if self.warmup_duration > 0 and not self.monotonic:
+            raise ValueError(
+                "Invalid combination when warmup_duration > 0 and monotonic=False, "
+                "please either set warmup_duration=0 or monotonic=True"
+            )
+
     def get_param(self) -> float:
         """Method to get current optimizer's parameter value"""
         cycle_progress = self.event_index / self.cycle_size
-        return self.end_value + (self.start_value - self.end_value) * abs(cycle_progress - 0.5) * 2
+
+        if self.monotonic:
+            return self.start_value + (self.end_value - self.start_value) * cycle_progress
+        else:
+            return self.end_value + (self.start_value - self.end_value) * abs(cycle_progress - 0.5) * 2
 
 
 class CosineAnnealingScheduler(CyclicalScheduler):
diff --git a/tests/ignite/handlers/test_param_scheduler.py b/tests/ignite/handlers/test_param_scheduler.py
index eb70ab3a082..af5f5cae497 100644
--- a/tests/ignite/handlers/test_param_scheduler.py
+++ b/tests/ignite/handlers/test_param_scheduler.py
@@ -68,6 +68,13 @@ def test_linear_scheduler_asserts():
     with pytest.raises(ValueError, match=r"Argument cycle_size should be positive and larger than 1"):
         LinearCyclicalScheduler(optimizer, "lr", 1, 0, cycle_size=1)
 
+    with pytest.raises(
+        ValueError,
+        match=r"Invalid combination when warmup_duration > 0 and monotonic=False, "
+        r"please either set warmup_duration=0 or monotonic=True",
+    ):
+        LinearCyclicalScheduler(optimizer, "lr", 1, 0, cycle_size=2, warmup_duration=1)
+
 
 def test_linear_scheduler():
     tensor = torch.zeros([1], requires_grad=True)
@@ -144,6 +151,102 @@ def save_lr(engine):
     scheduler.load_state_dict(state_dict)
 
 
+def test_linear_scheduler_warmup_duration():
+    tensor = torch.zeros([1], requires_grad=True)
+    optimizer = torch.optim.SGD([tensor], lr=0.0)
+
+    scheduler = LinearCyclicalScheduler(optimizer, "lr", 1, 0, 10, warmup_duration=5, monotonic=True)
+    state_dict = scheduler.state_dict()
+
+    def save_lr(engine):
+        lrs.append(optimizer.param_groups[0]["lr"])
+
+    trainer = Engine(lambda engine, batch: None)
+    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
+    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)
+    lr_values_in_cycle = [
+        1.0,
+        0.9,
+        0.8,
+        0.7,
+        0.6,
+        0.5,
+        0.4,
+        0.3,
+        0.2,
+        0.1,
+        0.0,
+        0.2,
+        0.4,
+        0.6,
+        0.8,
+        1.0,
+        0.9,
+        0.8,
+        0.7,
+        0.6,
+    ]
+    for _ in range(2):
+        lrs = []
+        trainer.run([0] * 10, max_epochs=2)
+
+        assert lrs == pytest.approx(lr_values_in_cycle)
+        scheduler.load_state_dict(state_dict)
+
+    optimizer = torch.optim.SGD([tensor], lr=0)
+    scheduler = LinearCyclicalScheduler(optimizer, "lr", 1, 0, 10, cycle_mult=2, warmup_duration=5, monotonic=True)
+    state_dict = scheduler.state_dict()
+
+    trainer = Engine(lambda engine, batch: None)
+    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
+    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)
+
+    for _ in range(2):
+        lrs = []
+        trainer.run([0] * 10, max_epochs=3)
+
+        assert lrs == list(
+            map(
+                pytest.approx,
+                [
+                    # Cycle 1
+                    1.0,
+                    0.9,
+                    0.8,
+                    0.7,
+                    0.6,
+                    0.5,
+                    0.4,
+                    0.3,
+                    0.2,
+                    0.1,
+                    0.0,
+                    0.2,
+                    0.4,
+                    0.6,
+                    0.8,
+                    # Cycle 2
+                    1.0,
+                    0.95,
+                    0.9,
+                    0.85,
+                    0.8,
+                    0.75,
+                    0.7,
+                    0.65,
+                    0.6,
+                    0.55,
+                    0.5,
+                    0.45,
+                    0.4,
+                    0.35,
+                    0.3,
+                ],
+            )
+        )
+        scheduler.load_state_dict(state_dict)
+
+
 def test_linear_scheduler_cycle_size_two():
     tensor = torch.zeros([1], requires_grad=True)
     optimizer = torch.optim.SGD([tensor], lr=0)
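
For reviewers, a minimal usage sketch of the behaviour this patch adds. It is illustrative only, not part of the patch, and mirrors the new test_linear_scheduler_warmup_duration above; the commented output assumes start_value=1, end_value=0, cycle_size=10, warmup_duration=5, monotonic=True.

    import torch
    from ignite.engine import Engine, Events
    from ignite.handlers.param_scheduler import LinearCyclicalScheduler

    tensor = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0.0)

    # monotonic=True schedules only the descending half of the cycle
    # (1.0 -> 0.0 over cycle_size=10 iterations); warmup_duration=5 then
    # ramps the value back up to start_value before the next cycle begins.
    scheduler = LinearCyclicalScheduler(
        optimizer, "lr", 1.0, 0.0, 10, warmup_duration=5, monotonic=True
    )

    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_lr(engine):
        print(engine.state.iteration, optimizer.param_groups[0]["lr"])

    trainer.run([0] * 10, max_epochs=2)
    # Logged lr values: 1.0, 0.9, ..., 0.1, 0.0 (descending half-cycle),
    # then 0.2, 0.4, 0.6, 0.8 (warm-up), then 1.0, 0.9, ... (next cycle)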