diff --git a/autoPyTorch/pipeline/components/setup/optimizer/SGDOptimizer.py b/autoPyTorch/pipeline/components/setup/optimizer/SGDOptimizer.py
index 4396cb381..3c310b492 100644
--- a/autoPyTorch/pipeline/components/setup/optimizer/SGDOptimizer.py
+++ b/autoPyTorch/pipeline/components/setup/optimizer/SGDOptimizer.py
@@ -82,6 +82,7 @@ def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None,
         lr = UniformFloatHyperparameter('lr', lower=lr[0][0], upper=lr[0][1],
                                         default_value=lr[1], log=lr[2])
 
+        # TODO: should be refactored into L2 regularization in the future
         weight_decay = UniformFloatHyperparameter('weight_decay', lower=weight_decay[0][0],
                                                   upper=weight_decay[0][1],
                                                   default_value=weight_decay[1])
diff --git a/autoPyTorch/pipeline/components/setup/optimizer/SGDWOptimizer.py b/autoPyTorch/pipeline/components/setup/optimizer/SGDWOptimizer.py
new file mode 100644
index 000000000..62e571a8e
--- /dev/null
+++ b/autoPyTorch/pipeline/components/setup/optimizer/SGDWOptimizer.py
@@ -0,0 +1,201 @@
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple
+
+from ConfigSpace.configuration_space import ConfigurationSpace
+from ConfigSpace.hyperparameters import (
+    UniformFloatHyperparameter,
+)
+
+import numpy as np
+
+import torch
+from torch.optim.optimizer import Optimizer
+
+from autoPyTorch.pipeline.components.setup.optimizer.base_optimizer import BaseOptimizerComponent
+
+
+class SGDW(Optimizer):
+    r"""Implements stochastic gradient descent with decoupled weight decay
+    (SGDW), optionally with momentum.
+
+    Unlike plain SGD, the weight decay is not added to the gradient as an
+    L2 penalty; the weights are shrunk directly in the update step, as
+    proposed in `Decoupled Weight Decay Regularization`
+    (Loshchilov & Hutter, 2019).
+
+    Nesterov momentum is based on the formula from
+    `On the importance of initialization and momentum in deep learning`__.
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float): learning rate
+        momentum (float, optional): momentum factor (default: 0)
+        weight_decay (float, optional): decoupled weight decay factor (default: 0)
+        dampening (float, optional): dampening for momentum (default: 0)
+        nesterov (bool, optional): enables Nesterov momentum (default: False)
+
+    Example:
+        >>> optimizer = SGDW(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
+        >>> optimizer.zero_grad()
+        >>> loss_fn(model(input), target).backward()
+        >>> optimizer.step()
+
+    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
+
+    .. note::
+        The implementation of SGD with Momentum/Nesterov subtly differs from
+        Sutskever et. al. and implementations in some other frameworks.
+
+        Considering the specific case of Momentum, the update can be written as
+
+        .. math::
+            v = \rho * v + g \\
+            p = p - lr * v
+
+        where p, g, v and :math:`\rho` denote the parameters, gradient,
+        velocity, and momentum respectively.
+
+        This is in contrast to Sutskever et. al. and
+        other frameworks which employ an update of the form
+
+        .. math::
+            v = \rho * v + lr * g \\
+            p = p - v
+
+        The Nesterov version is analogously modified.
+    """
+    def __init__(
+        self,
+        params: Iterable,
+        lr: float,
+        weight_decay: float = 0,
+        momentum: float = 0,
+        dampening: float = 0,
+        nesterov: bool = False,
+    ):
+        if lr < 0.0:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if momentum < 0.0:
+            raise ValueError("Invalid momentum value: {}".format(momentum))
+        if weight_decay < 0.0:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+
+        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
+                        weight_decay=weight_decay, nesterov=nesterov)
+        if nesterov and (momentum <= 0 or dampening != 0):
+            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
+        super(SGDW, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(SGDW, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('nesterov', False)
+
+    def step(
+        self,
+        closure: Optional[Callable] = None
+    ):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            weight_decay = group['weight_decay']
+            momentum = group['momentum']
+            dampening = group['dampening']
+            nesterov = group['nesterov']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                d_p = p.grad.data
+
+                if momentum != 0:
+                    param_state = self.state[p]
+                    if 'momentum_buffer' not in param_state:
+                        # Initialise the buffer with the raw gradient
+                        buf = param_state['momentum_buffer'] = torch.clone(
+                            d_p
+                        ).detach()
+                    else:
+                        buf = param_state['momentum_buffer']
+                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
+                    if nesterov:
+                        d_p = d_p.add(buf, alpha=momentum)
+                    else:
+                        d_p = buf
+
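+                # SGDW decouples the weight decay from the gradient update:
+                # plain SGD would fold `weight_decay * p` into `d_p` as an
+                # L2 penalty, whereas here the weights are shrunk directly
+                # after the gradient step.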
+ """ + def __init__( + self, + params: Iterable, + lr: float, + weight_decay: float = 0, + momentum: float = 0, + dampening: float = 0, + nesterov: bool = False, + ): + if lr < 0.0: + raise ValueError("Invalid learning rate: {}".format(lr)) + if momentum < 0.0: + raise ValueError("Invalid momentum value: {}".format(momentum)) + if weight_decay < 0.0: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + defaults = dict(lr=lr, momentum=momentum, dampening=dampening, + weight_decay=weight_decay, nesterov=nesterov) + if nesterov and (momentum <= 0 or dampening != 0): + raise ValueError("Nesterov momentum requires a momentum and zero dampening") + super(SGDW, self).__init__(params, defaults) + + def __setstate__(self, state): + super(SGDW, self).__setstate__(state) + for group in self.param_groups: + group.setdefault('nesterov', False) + + def step( + self, + closure: Optional[Callable] = None + ): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + weight_decay = group['weight_decay'] + momentum = group['momentum'] + dampening = group['dampening'] + nesterov = group['nesterov'] + + for p in group['params']: + if p.grad is None: + continue + d_p = p.grad.data + + if momentum != 0: + param_state = self.state[p] + if 'momentum_buffer' not in param_state: + buf = param_state['momentum_buffer'] = torch.clone( + d_p + ).detach() + else: + buf = param_state['momentum_buffer'] + buf.mul_(momentum).add_(d_p, alpha=1 - dampening) + if nesterov: + d_p = d_p.add(momentum, buf) + else: + d_p = buf + + # Apply momentum + p.data.add_(d_p, alpha=-group['lr']) + + # Apply weight decay + if weight_decay != 0: + p.data.add_(weight_decay, alpha=-group['lr']) + return loss + + +class SGDWOptimizer(BaseOptimizerComponent): + """ + Implements Stochstic Gradient Descend algorithm. + + Args: + lr (float): learning rate (default: 1e-2) + momentum (float): momentum factor (default: 0) + weight_decay (float): weight decay (L2 penalty) (default: 0) + random_state (Optional[np.random.RandomState]): random state + """ + def __init__( + self, + lr: float, + momentum: float, + weight_decay: float, + random_state: Optional[np.random.RandomState] = None, + ): + + super().__init__() + self.lr = lr + self.momentum = momentum + self.weight_decay = weight_decay + self.random_state = random_state + + def fit(self, X: Dict[str, Any], y: Any = None) -> BaseOptimizerComponent: + """ + Fits a component by using an input dictionary with pre-requisites + + Args: + X (X: Dict[str, Any]): Dependencies needed by current component to perform fit + y (Any): not used. 
+class SGDWOptimizer(BaseOptimizerComponent):
+    """
+    Implements the stochastic gradient descent algorithm with decoupled
+    weight decay (SGDW).
+
+    Args:
+        lr (float): learning rate (default: 1e-2)
+        momentum (float): momentum factor (default: 0)
+        weight_decay (float): decoupled weight decay factor (default: 0)
+        random_state (Optional[np.random.RandomState]): random state
+    """
+    def __init__(
+        self,
+        lr: float,
+        momentum: float,
+        weight_decay: float,
+        random_state: Optional[np.random.RandomState] = None,
+    ):
+
+        super().__init__()
+        self.lr = lr
+        self.momentum = momentum
+        self.weight_decay = weight_decay
+        self.random_state = random_state
+
+    def fit(self, X: Dict[str, Any], y: Any = None) -> BaseOptimizerComponent:
+        """
+        Fits a component by using an input dictionary with pre-requisites
+
+        Args:
+            X (Dict[str, Any]): Dependencies needed by current component to perform fit
+            y (Any): not used. To comply with sklearn API
+
+        Returns:
+            An instance of self
+        """
+
+        # Make sure that the input dictionary X has the required
+        # information to fit this stage
+        self.check_requirements(X, y)
+
+        self.optimizer = SGDW(
+            params=X['network'].parameters(),
+            lr=self.lr,
+            weight_decay=self.weight_decay,
+            momentum=self.momentum,
+        )
+
+        return self
+
+    @staticmethod
+    def get_properties(dataset_properties: Optional[Dict[str, Any]] = None) -> Dict[str, str]:
+        return {
+            'shortname': 'SGDW',
+            'name': 'Stochastic gradient descent (optionally with momentum) with decoupled weight decay',
+        }
+
+    @staticmethod
+    def get_hyperparameter_search_space(dataset_properties: Optional[Dict] = None,
+                                        lr: Tuple[Tuple, float, bool] = ((1e-5, 1e-1), 1e-2, True),
+                                        weight_decay: Tuple[Tuple, float] = ((0.0, 0.1), 0.0),
+                                        momentum: Tuple[Tuple, float] = ((0.0, 0.99), 0.0),
+                                        ) -> ConfigurationSpace:
+
+        cs = ConfigurationSpace()
+
+        # The learning rate for the model
+        lr = UniformFloatHyperparameter('lr', lower=lr[0][0], upper=lr[0][1],
+                                        default_value=lr[1], log=lr[2])
+
+        weight_decay = UniformFloatHyperparameter('weight_decay', lower=weight_decay[0][0],
+                                                  upper=weight_decay[0][1],
+                                                  default_value=weight_decay[1])
+
+        momentum = UniformFloatHyperparameter('momentum', lower=momentum[0][0],
+                                              upper=momentum[0][1],
+                                              default_value=momentum[1])
+
+        cs.add_hyperparameters([lr, weight_decay, momentum])
+
+        return cs
diff --git a/autoPyTorch/pipeline/components/setup/optimizer/base_optimizer_choice.py b/autoPyTorch/pipeline/components/setup/optimizer/base_optimizer_choice.py
index 5196f0bb7..a6a1f9bd5 100644
--- a/autoPyTorch/pipeline/components/setup/optimizer/base_optimizer_choice.py
+++ b/autoPyTorch/pipeline/components/setup/optimizer/base_optimizer_choice.py
@@ -137,7 +137,8 @@ def get_hyperparameter_search_space(
             'AdamOptimizer',
             'AdamWOptimizer',
             'SGDOptimizer',
-            'RMSpropOptimizer'
+            'SGDWOptimizer',
+            'RMSpropOptimizer',
         ]
         for default_ in defaults:
            if default_ in available_optimizer:
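
Usage sketch (not part of the patch): a minimal way to exercise the new SGDW
class and component, assuming a toy linear model and random data (the names
`model`, `x` and `y` are illustrative only):

    import torch
    from autoPyTorch.pipeline.components.setup.optimizer.SGDWOptimizer import (
        SGDW,
        SGDWOptimizer,
    )

    model = torch.nn.Linear(4, 1)
    # weight_decay is the decoupled decay factor, not an L2 penalty
    optimizer = SGDW(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

    x, y = torch.randn(8, 4), torch.randn(8, 1)
    for _ in range(100):
        optimizer.zero_grad()
        torch.nn.functional.mse_loss(model(x), y).backward()
        optimizer.step()

    # The component exposes lr, weight_decay and momentum to the HPO search
    cs = SGDWOptimizer.get_hyperparameter_search_space()
    config = cs.sample_configuration()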