Skip to content

Commit

Permalink
Merge pull request #103 from f-dangel/release
Browse files Browse the repository at this point in the history
[Part 3, v1.2.0] Merge release into master
  • Loading branch information
f-dangel authored Oct 26, 2020
2 parents f28ed2b + 5636548 commit 6a1ac37
Show file tree
Hide file tree
Showing 173 changed files with 6,585 additions and 1,471 deletions.
7 changes: 4 additions & 3 deletions .conda_env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ channels:
- pytorch
- defaults
dependencies:
- cudatoolkit=9.2=0
- pip=19.3.1
- python=3.7.6
- pytorch=1.3.1=py3.7_cuda9.2.148_cudnn7.6.3_0
- torchvision=0.4.2=py37_cu92
- pip:
- -r requirements.txt
- -r requirements-dev.txt
- -e .
# Note: Enabling CUDA:
# 1. CUDA Version: cat /usr/local/cuda/version.txt
# 2. Find the command: https://pytorch.org/get-started/locally/
# - pip install torch==1.5.0+cu101 torchvision==0.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
3 changes: 1 addition & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
language: python
python:
- '3.5'
- '3.6'
- '3.7'
- '3.8'
install:
- pip install -r requirements.txt
- pip install -r requirements/test.txt
- pip install .
- pip install pillow==6.1.0
cache:
- pip
script:
Expand Down
2 changes: 1 addition & 1 deletion README-dev.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# <img alt="BackPACK" src="./logo/backpack_logo_torch.svg" height="90"> BackPACK developer manual

## General standards
- Python version: support 3.5+, use 3.7 for development
- Python version: support 3.6+, use 3.7 for development
- `git` [branching model](https://nvie.com/posts/a-successful-git-branching-model/)
- Docstring style: [Google](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html)
- Test runner: [`pytest`](https://docs.pytest.org/en/latest/)
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

[![Travis](https://travis-ci.org/f-dangel/backpack.svg?branch=master)](https://travis-ci.org/f-dangel/backpack)
[![Coveralls](https://coveralls.io/repos/github/f-dangel/backpack/badge.svg?branch=master)](https://coveralls.io/github/f-dangel/backpack)
[![Python 3.5+](https://img.shields.io/badge/python-3.5+-blue.svg)](https://www.python.org/downloads/release/python-350/)
[![Python 3.6+](https://img.shields.io/badge/python-3.6+-blue.svg)](https://www.python.org/downloads/release/python-360/)

BackPACK is built on top of [PyTorch](https://github.com/pytorch/pytorch). It efficiently computes quantities other than the gradient.

Expand Down
9 changes: 8 additions & 1 deletion backpack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect

import torch

from backpack.extensions.backprop_extension import BackpropExtension

from . import extensions
Expand Down Expand Up @@ -126,7 +127,13 @@ def hook_run_extensions(module, g_inp, g_out):
print("[DEBUG] Running extension", backpack_extension, "on", module)
backpack_extension.apply(module, g_inp, g_out)

if not CTX.is_extension_active(extensions.curvmatprod.CMP):
if not (
CTX.is_extension_active(
extensions.curvmatprod.HMP,
extensions.curvmatprod.GGNMP,
extensions.curvmatprod.PCHMP,
)
):
memory_cleanup(module)


Expand Down
4 changes: 2 additions & 2 deletions backpack/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ def remove_hooks():
CTX.hook_handles = []

@staticmethod
def is_extension_active(extension_class):
def is_extension_active(*extension_classes):
for backpack_ext in CTX.get_active_exts():
if isinstance(backpack_ext, extension_class):
if isinstance(backpack_ext, extension_classes):
return True
return False

Expand Down
31 changes: 29 additions & 2 deletions backpack/core/derivatives/__init__.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,66 @@
from torch.nn import (
ELU,
SELU,
AvgPool2d,
Conv1d,
Conv2d,
Conv3d,
ConvTranspose1d,
ConvTranspose2d,
ConvTranspose3d,
CrossEntropyLoss,
MSELoss,
Dropout,
LeakyReLU,
Linear,
LogSigmoid,
MaxPool2d,
MSELoss,
ReLU,
Sigmoid,
Tanh,
ZeroPad2d,
)

from .avgpool2d import AvgPool2DDerivatives
from .conv1d import Conv1DDerivatives
from .conv_transpose1d import ConvTranspose1DDerivatives
from .conv2d import Conv2DDerivatives
from .conv_transpose2d import ConvTranspose2DDerivatives
from .conv3d import Conv3DDerivatives
from .conv_transpose3d import ConvTranspose3DDerivatives
from .crossentropyloss import CrossEntropyLossDerivatives
from .mseloss import MSELossDerivatives
from .dropout import DropoutDerivatives
from .elu import ELUDerivatives
from .leakyrelu import LeakyReLUDerivatives
from .linear import LinearDerivatives
from .logsigmoid import LogSigmoidDerivatives
from .maxpool2d import MaxPool2DDerivatives
from .mseloss import MSELossDerivatives
from .relu import ReLUDerivatives
from .selu import SELUDerivatives
from .sigmoid import SigmoidDerivatives
from .tanh import TanhDerivatives
from .zeropad2d import ZeroPad2dDerivatives

derivatives_for = {
Linear: LinearDerivatives,
Conv1d: Conv1DDerivatives,
Conv2d: Conv2DDerivatives,
Conv3d: Conv3DDerivatives,
AvgPool2d: AvgPool2DDerivatives,
MaxPool2d: MaxPool2DDerivatives,
ZeroPad2d: ZeroPad2dDerivatives,
Dropout: DropoutDerivatives,
ReLU: ReLUDerivatives,
Tanh: TanhDerivatives,
Sigmoid: SigmoidDerivatives,
ConvTranspose1d: ConvTranspose1DDerivatives,
ConvTranspose2d: ConvTranspose2DDerivatives,
ConvTranspose3d: ConvTranspose3DDerivatives,
LeakyReLU: LeakyReLUDerivatives,
LogSigmoid: LogSigmoidDerivatives,
ELU: ELUDerivatives,
SELU: SELUDerivatives,
CrossEntropyLoss: CrossEntropyLossDerivatives,
MSELoss: MSELossDerivatives,
}
3 changes: 2 additions & 1 deletion backpack/core/derivatives/avgpool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
convolution over single channels with a constant kernel."""

import torch.nn
from torch.nn import AvgPool2d, Conv2d, ConvTranspose2d

from backpack.core.derivatives.basederivatives import BaseDerivatives
from backpack.utils.ein import eingroup
from torch.nn import AvgPool2d, Conv2d, ConvTranspose2d


class AvgPool2DDerivatives(BaseDerivatives):
Expand Down
63 changes: 41 additions & 22 deletions backpack/core/derivatives/basederivatives.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Base classes for more flexible Jacobians and second-order information."""

import warnings

from backpack.core.derivatives import shape_check
Expand All @@ -24,12 +23,12 @@ class BaseDerivatives:
For simplicity, consider the vector case, i.e. a function which maps an
`[N, D_in]` `input` into an `[N, D_out]` `output`.
The Jacobian `J` of is tensor of shape `[N, D_out, N_in, D_in]`.
The input-output Jacobian `J` of is tensor of shape `[N, D_out, N_in, D_in]`.
Partial derivatives are ordered as
`J[i, j, k, l] = 𝜕output[i, j] / 𝜕input[k, l].
The transposed Jacobian `Jᵀ` has shape `[N, D_in, N, D_out]`.
The transposed input-output Jacobian `Jᵀ` has shape `[N, D_in, N, D_out]`.
Partial derivatives are ordered as
`Jᵀ[i, j, k, l] = 𝜕output[k, l] / 𝜕input[i, j]`.
Expand Down Expand Up @@ -67,13 +66,13 @@ def jac_mat_prod(self, module, g_inp, g_out, mat):
return self._jac_mat_prod(module, g_inp, g_out, mat)

def _jac_mat_prod(self, module, g_inp, g_out, mat):
"""Internal implementation of the Jacobian."""
"""Internal implementation of the input-output Jacobian."""
raise NotImplementedError

@shape_check.jac_t_mat_prod_accept_vectors
@shape_check.jac_t_mat_prod_check_shapes
def jac_t_mat_prod(self, module, g_inp, g_out, mat):
"""Apply transposed Jacobian of module output w.r.t. input to a matrix.
"""Apply transposed input-ouput Jacobian of module output to a matrix.
Implicit application of Jᵀ:
result[v, ̃n, ̃c, ̃w, ...]
Expand All @@ -100,45 +99,65 @@ def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
# TODO Add shape check
# TODO Use new convention
def ea_jac_t_mat_jac_prod(self, module, g_inp, g_out, mat):
"""Expectation approximation of outer product with input-output Jacobian.
Used for backpropagation in KFRA.
For `yₙ = f(xₙ) n=1,...,n`, compute `E(Jₙᵀ mat Jₙ) = 1/n ∑ₙ Jₙᵀ mat Jₙ`.
In index notation, let `output[n]=f(input[n]) n = 1,...,n`. Then,
`result[i,j]
= 1/n ∑ₙₖₗ (𝜕output[n,k] / 𝜕input[n,i]) mat[k,l] (𝜕output[n,j] / 𝜕input[n,l])
Args:
module (torch.nn.Module): Extended module.
g_inp ([torch.Tensor]): Gradients of the module w.r.t. its inputs.
g_out ([torch.Tensor]): Gradients of the module w.r.t. its outputs.
mat (torch.Tensor): Matrix of shape `[D_out, D_out]`.
Returns:
torch.Tensor: Matrix of shape `[D_in, D_in]`.
Note:
- This operation can be applied without knowledge about backpropagated
derivatives. Both `g_inp` and `g_out` are usually not required and
can be set to `None`.
"""
raise NotImplementedError

def hessian_is_zero(self):
raise NotImplementedError

def hessian_is_diagonal(self):
"""Is `∂²output[i] / ∂input[j] ∂input[k]` nonzero only if `i = j = k`."""
raise NotImplementedError

def hessian_diagonal(self):
"""Return `∂²output[i] / ∂input[i]²`.
Only required if `hessian_is_diagonal` returns `True`.
"""
raise NotImplementedError

def hessian_is_psd(self):
"""Is `∂²output[i] / ∂input[j] ∂input[k]` positive semidefinite (PSD)."""
raise NotImplementedError

# TODO make accept vectors
# TODO add shape check
def make_residual_mat_prod(self, module, g_inp, g_out):
"""Return multiplication routine with the residual term.
@shape_check.residual_mat_prod_accept_vectors
@shape_check.residual_mat_prod_check_shapes
def residual_mat_prod(self, module, g_inp, g_out, mat):
"""Multiply with the residual term.
The function performs the mapping: mat → [∑_{k} Hz_k(x) 𝛿z_k] mat.
(required for extension `curvmatprod`)
Performs mat → [∑_{k} Hz_k(x) 𝛿z_k] mat.
Note:
-----
This function only has to be implemented if the residual is not
zero and not diagonal (for instance, `BatchNorm`).
"""
raise NotImplementedError
return self._residual_mat_prod(module, g_inp, g_out, mat)

# TODO Refactor and remove
def batch_flat(self, tensor):
batch = tensor.size(0)
# TODO Removing the clone().detach() will destroy the computation graph
# Tests will fail
return batch, tensor.clone().detach().view(batch, -1)

# TODO Refactor and remove
def get_batch(self, module):
return module.input0.size(0)
def _residual_mat_prod(self, module, g_inp, g_out, mat):
raise NotImplementedError

@staticmethod
def _reshape_like(mat, like):
Expand Down
57 changes: 39 additions & 18 deletions backpack/core/derivatives/batchnorm1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@
from torch.nn import BatchNorm1d

from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
from backpack.core.derivatives.shape_check import (
R_mat_prod_accept_vectors,
R_mat_prod_check_shapes,
)


class BatchNorm1dDerivatives(BaseParameterDerivatives):
Expand Down Expand Up @@ -40,7 +36,7 @@ def _jac_t_mat_prod(self, module, g_inp, g_out, mat):
"""
assert module.affine is True

N = self.get_batch(module)
N = module.input0.size(0)
x_hat, var = self.get_normalized_input_and_var(module)
ivar = 1.0 / (var + module.eps).sqrt()

Expand All @@ -59,20 +55,45 @@ def get_normalized_input_and_var(self, module):
var = input.var(dim=0, unbiased=False)
return (input - mean) / (var + module.eps).sqrt(), var

@R_mat_prod_accept_vectors
@R_mat_prod_check_shapes
def make_residual_mat_prod(self, module, g_inp, g_out):
# TODO: Implement R_mat_prod for BatchNorm
def R_mat_prod(mat):
"""Multiply with the residual: mat → [∑_{k} Hz_k(x) 𝛿z_k] mat.
def _residual_mat_prod(self, module, g_inp, g_out, mat):
"""Multiply with BatchNorm1d residual-matrix.
Second term of the module input Hessian backpropagation equation.
"""
raise NotImplementedError
Paul Fischer (GitHub: @paulkogni) contributed this code during a research
project in winter 2019.
# TODO: Enable tests in test/automated_bn_test.py
raise NotImplementedError
return R_mat_prod
Details are described in
- `TODO: Add tech report title`
<TODO: Wait for tech report upload>_
by Paul Fischer, 2020.
"""
N = module.input0.size(0)
x_hat, var = self.get_normalized_input_and_var(module)
gamma = module.weight
eps = module.eps

factor = gamma / (N * (var + eps))

sum_127 = einsum("nc,vnc->vc", (x_hat, mat))
sum_24 = einsum("nc->c", g_out[0])
sum_3 = einsum("nc,vnc->vc", (g_out[0], mat))
sum_46 = einsum("vnc->vc", mat)
sum_567 = einsum("nc,nc->c", (x_hat, g_out[0]))

r_mat = -einsum("nc,vc->vnc", (g_out[0], sum_127))
r_mat += (1.0 / N) * einsum("c,vc->vc", (sum_24, sum_127)).unsqueeze(1).expand(
-1, N, -1
)
r_mat -= einsum("nc,vc->vnc", (x_hat, sum_3))
r_mat += (1.0 / N) * einsum("nc,c,vc->vnc", (x_hat, sum_24, sum_46))

r_mat -= einsum("vnc,c->vnc", (mat, sum_567))
r_mat += (1.0 / N) * einsum("c,vc->vc", (sum_567, sum_46)).unsqueeze(1).expand(
-1, N, -1
)
r_mat += (3.0 / N) * einsum("nc,vc,c->vnc", (x_hat, sum_127, sum_567))

return einsum("c,vnc->vnc", (factor, r_mat))

def _weight_jac_mat_prod(self, module, g_inp, g_out, mat):
x_hat, _ = self.get_normalized_input_and_var(module)
Expand All @@ -90,7 +111,7 @@ def _weight_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch):
return einsum(equation, operands)

def _bias_jac_mat_prod(self, module, g_inp, g_out, mat):
N = self.get_batch(module)
N = module.input0.size(0)
return mat.unsqueeze(1).repeat(1, N, 1)

def _bias_jac_t_mat_prod(self, module, g_inp, g_out, mat, sum_batch=True):
Expand Down
6 changes: 6 additions & 0 deletions backpack/core/derivatives/conv1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from backpack.core.derivatives.convnd import ConvNDDerivatives


class Conv1DDerivatives(ConvNDDerivatives):
def __init__(self):
super().__init__(N=1)
Loading

0 comments on commit 6a1ac37

Please sign in to comment.