Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sanity checks for second order backpropagation #60

Merged
merged 3 commits into from
Apr 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions backpack/core/derivatives/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from torch.nn import (
AvgPool2d,
Conv2d,
CrossEntropyLoss,
MSELoss,
Dropout,
Linear,
MaxPool2d,
Expand All @@ -12,6 +14,8 @@

from .avgpool2d import AvgPool2DDerivatives
from .conv2d import Conv2DDerivatives
from .crossentropyloss import CrossEntropyLossDerivatives
from .mseloss import MSELossDerivatives
from .dropout import DropoutDerivatives
from .linear import LinearDerivatives
from .maxpool2d import MaxPool2DDerivatives
Expand All @@ -30,4 +34,6 @@
ReLU: ReLUDerivatives,
Tanh: TanhDerivatives,
Sigmoid: SigmoidDerivatives,
CrossEntropyLoss: CrossEntropyLossDerivatives,
MSELoss: MSELossDerivatives,
}
39 changes: 36 additions & 3 deletions backpack/core/derivatives/basederivatives.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

from backpack.core.derivatives import shape_check


Expand Down Expand Up @@ -269,13 +271,12 @@ def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):


class BaseLossDerivatives(BaseDerivatives):
"""Second- order partial derivatives of loss functions.

"""
"""Second- order partial derivatives of loss functions."""

# TODO Add shape check
def sqrt_hessian(self, module, g_inp, g_out):
    """Return a symmetric factorization ('sqrt') of the loss Hessian.

    The sanity checks for second-order quantities run first; the actual
    computation is delegated to `_sqrt_hessian`.
    """
    self.check_2nd_order_make_sense(module, g_inp, g_out)
    sqrt_h = self._sqrt_hessian(module, g_inp, g_out)
    return sqrt_h

def _sqrt_hessian(self, module, g_inp, g_out):
Expand All @@ -284,6 +285,7 @@ def _sqrt_hessian(self, module, g_inp, g_out):
# TODO Add shape check
def sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1):
    """Return a Monte-Carlo sampled symmetric factorization of the loss Hessian.

    The sanity checks for second-order quantities run first; the actual
    computation is delegated to `_sqrt_hessian_sampled`, which receives
    `mc_samples` as a keyword argument.
    """
    self.check_2nd_order_make_sense(module, g_inp, g_out)
    sampled_sqrt_h = self._sqrt_hessian_sampled(
        module, g_inp, g_out, mc_samples=mc_samples
    )
    return sampled_sqrt_h

def _sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1):
Expand All @@ -296,6 +298,7 @@ def make_hessian_mat_prod(self, module, g_inp, g_out):

Return a function that maps mat to H * mat.
"""
self.check_2nd_order_make_sense(module, g_inp, g_out)
return self._make_hessian_mat_prod(module, g_inp, g_out)

def _make_hessian_mat_prod(self, module, g_inp, g_out):
Expand All @@ -304,7 +307,37 @@ def _make_hessian_mat_prod(self, module, g_inp, g_out):
# TODO Add shape check
def sum_hessian(self, module, g_inp, g_out):
    """Return the loss Hessians summed over the batch dimension.

    The sanity checks for second-order quantities run first; the actual
    computation is delegated to `_sum_hessian`.
    """
    self.check_2nd_order_make_sense(module, g_inp, g_out)
    summed_h = self._sum_hessian(module, g_inp, g_out)
    return summed_h

def _sum_hessian(self, module, g_inp, g_out):
raise NotImplementedError

def check_2nd_order_make_sense(self, module, g_inp, g_out):
    """Run the sanity checks required by 2nd-order extensions.

    2nd-order extensions are only guaranteed to work if the `loss` on
    which `backward()` is called is a scalar that was not modified any
    further after leaving the loss function module. Two checks are
    performed, in this order: the module output must be a scalar, and the
    gradient with respect to that output must look like the identity.

    `g_inp` is accepted for signature compatibility but is not used.
    """
    self._check_output_is_scalar(module)
    self._check_loss_has_not_been_modified(module, g_out)

def _check_output_is_scalar(self, module):
"""Raise an exception is the module output is not a scalar."""
if module.output.numel() != 1:
raise ValueError(
"Output must be scalar. Got {}".format(module.output.shape)
)

def _check_loss_has_not_been_modified(self, module, g_out):
"""Raise a warning if the module output seems to have been changed."""
grad_out_is_identity = g_out is None or (g_out[0] == 1.0).all().item()
if not grad_out_is_identity:
warnings.warn(
"The output of {} seems to have been modified.".format(module)
+ " Backpack might give wrong second-order information."
+ " Make sure you call backward() on the output of a loss"
+ " function module from torch.nn",
UserWarning,
)
53 changes: 48 additions & 5 deletions backpack/core/derivatives/crossentropyloss.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Partial derivatives for cross-entropy loss."""
from math import sqrt

from torch import diag, diag_embed, einsum, multinomial, ones_like, softmax
Expand All @@ -9,11 +10,20 @@


class CrossEntropyLossDerivatives(BaseLossDerivatives):
"""Partial derivatives for cross-entropy loss.

The `torch.nn.CrossEntropyLoss` operation is a composition of softmax
and negative log-likelihood.
"""

def get_module(self):
    """Return the `torch.nn` module class handled here: `CrossEntropyLoss`."""
    return CrossEntropyLoss

def _sqrt_hessian(self, module, g_inp, g_out):
probs = self.get_probs(module)
self._check_2nd_order_parameters(module)

probs = self._get_probs(module)
tau = torchsqrt(probs)
V_dim, C_dim = 0, 2
Id = diag_embed(ones_like(probs), dim1=V_dim, dim2=C_dim)
Expand All @@ -27,10 +37,12 @@ def _sqrt_hessian(self, module, g_inp, g_out):
return sqrt_H

def _sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1):
self._check_2nd_order_parameters(module)

M = mc_samples
C = module.input0.shape[1]

probs = self.get_probs(module)
probs = self._get_probs(module)
V_dim = 0
probs_unsqueezed = probs.unsqueeze(V_dim).repeat(M, 1, 1)

Expand All @@ -47,7 +59,9 @@ def _sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1):
return sqrt_mc_h

def _sum_hessian(self, module, g_inp, g_out):
probs = self.get_probs(module)
self._check_2nd_order_parameters(module)

probs = self._get_probs(module)
sum_H = diag(probs.sum(0)) - einsum("bi,bj->ij", (probs, probs))

if module.reduction == "mean":
Expand All @@ -58,7 +72,9 @@ def _sum_hessian(self, module, g_inp, g_out):

def _make_hessian_mat_prod(self, module, g_inp, g_out):
"""Multiplication of the input Hessian with a matrix."""
probs = self.get_probs(module)
self._check_2nd_order_parameters(module)

probs = self._get_probs(module)

def hessian_mat_prod(mat):
Hmat = einsum("bi,cbi->cbi", (probs, mat)) - einsum(
Expand All @@ -74,7 +90,34 @@ def hessian_mat_prod(mat):
return hessian_mat_prod

def hessian_is_psd(self):
    """Report that the cross-entropy loss Hessian is positive semi-definite.

    Returns:
        bool: Always `True`.
    """
    return True

def get_probs(self, module):
def _get_probs(self, module):
return softmax(module.input0, dim=1)

def _check_2nd_order_parameters(self, module):
"""Verify that the parameters are supported by 2nd-order quantities.

Attributes:
module (torch.nn.CrossEntropyLoss): Extended CrossEntropyLoss module

Raises:
NotImplementedError: If module's setting is not implemented.
"""
implemented_ignore_index = -100
implemented_weight = None

if module.ignore_index != implemented_ignore_index:
raise NotImplementedError(
"Only default ignore_index ({}) is implemented, got {}".format(
implemented_ignore_index, module.ignore_index
)
)

if module.weight != implemented_weight:
raise NotImplementedError(
"Only default weight ({}) is implemented, got {}".format(
implemented_weight, module.weight
)
)
6 changes: 5 additions & 1 deletion backpack/core/derivatives/mseloss.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ def _sqrt_hessian(self, module, g_inp, g_out):
V_dim, C_dim = 0, 2
diag = sqrt(2) * ones_like(module.input0)
sqrt_H = diag_embed(diag, dim1=V_dim, dim2=C_dim)
N = module.input0_shape[0]

if module.reduction == "mean":
N = module.input0.shape[0]
sqrt_H /= sqrt(N)

return sqrt_H
Expand Down Expand Up @@ -74,6 +74,10 @@ def hessian_mat_prod(mat):
def check_input_dims(self, module):
    """Validate the shape of `module.input0` for MSELoss.

    Raises:
        ValueError: If the input does not have exactly two dimensions.
        NotImplementedError: If the second dimension is not of size 1.
    """
    input_shape = module.input0.shape

    if len(input_shape) != 2:
        raise ValueError("Only 2D inputs are currently supported for MSELoss.")

    if input_shape[1] != 1:
        raise NotImplementedError(
            "MSE between batches of vectors is not implemented yet."
        )

def hessian_is_psd(self):
    """Report that the loss Hessian is positive semi-definite.

    Returns:
        bool: Always `True`.
    """
    return True
94 changes: 94 additions & 0 deletions test/test_second_order_warnings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Checks wether backpack correctly recognizes when 2nd order backprop would fail.

Failure cases include
- the loss is not the output of a torch.nn module
- using unsupported parameters of the loss
"""

import pytest
import torch
from torch.nn import CrossEntropyLoss, MSELoss
from backpack import extend
from backpack import backpack as bp
import backpack.extensions as bpext

# Second-order extensions exercised by every parametrized test in this file.
ext_2nd_order = [
    bpext.KFAC,
    bpext.KFRA,
    bpext.KFLR,
    bpext.DiagGGNExact,
    bpext.DiagGGNMC,
]
# Test ids shown by pytest; entries line up one-to-one with `ext_2nd_order`.
ext_2nd_order_name = [
    "KFAC",
    "KFRA",
    "KFLR",
    "DiagGGNExact",
    "DiagGGNMC",
]


def classification_targets(N, num_classes):
    """Draw `N` random integer class labels from 0, ..., `num_classes - 1`."""
    labels = torch.randint(0, num_classes, (N,))
    return labels


def dummy_cross_entropy(N=5):
    """Return a cross-entropy loss on random two-class data of batch size `N`.

    The predictions require gradients and the loss module is wrapped with
    `extend` before it is applied.
    """
    num_classes = 2
    predictions = torch.rand((N, num_classes), requires_grad=True)
    targets = classification_targets(N, num_classes)
    criterion = extend(CrossEntropyLoss())
    return criterion(predictions, targets)


def dummy_mse(N=5, D=1):
    """Return an MSE loss on random `(N, D)` predictions and targets.

    The predictions require gradients and the loss module is wrapped with
    `extend` before it is applied.
    """
    predictions = torch.rand((N, D), requires_grad=True)
    targets = torch.randn((N, D))
    criterion = extend(MSELoss())
    return criterion(predictions, targets)


@pytest.mark.parametrize("extension", ext_2nd_order, ids=ext_2nd_order_name)
def test_sqrt_hessian_crossentropy_should_pass(extension):
    """Backward pass on an unmodified cross-entropy loss must run cleanly."""
    unmodified_loss = dummy_cross_entropy()

    with bp(extension()):
        unmodified_loss.backward()


@pytest.mark.parametrize("extension", ext_2nd_order, ids=ext_2nd_order_name)
def test_sqrt_hessian_mse_should_pass(extension):
    """Backward pass on an unmodified MSE loss must run cleanly."""
    unmodified_loss = dummy_mse()

    with bp(extension()):
        unmodified_loss.backward()


@pytest.mark.parametrize("extension", ext_2nd_order, ids=ext_2nd_order_name)
def test_sqrt_hessian_modified_crossentropy_should_fail(extension):
    """Scaling the cross-entropy loss after the module must trigger a warning."""
    modified_loss = dummy_cross_entropy() * 2

    with pytest.warns(UserWarning), bp(extension()):
        modified_loss.backward()


@pytest.mark.parametrize("extension", ext_2nd_order, ids=ext_2nd_order_name)
def test_sqrt_hessian_modified_mse_should_fail(extension):
    """Scaling the MSE loss after the module must trigger a warning."""
    modified_loss = dummy_mse() * 2

    with pytest.warns(UserWarning), bp(extension()):
        modified_loss.backward()


@pytest.mark.parametrize("extension", ext_2nd_order, ids=ext_2nd_order_name)
def test_sqrt_hessian_mse_on_vectors_should_fail(extension):
    """MSE on vector-valued predictions (D=2) must be rejected."""
    vector_loss = dummy_mse(D=2) * 2

    # The expected message is raised as NotImplementedError, which is a
    # subclass of RuntimeError, so matching on RuntimeError catches it.
    expected_message = r".*MSE between batches of vectors is not implemented yet.*"

    with pytest.raises(RuntimeError, match=expected_message), bp(extension()):
        vector_loss.backward()
Loading