Update master with code for 1.6.0 release #309

Merged
merged 30 commits into master on Jul 12, 2023
Conversation

f-dangel
Owner

No description provided.

jabader97 and others added 29 commits May 10, 2022 16:33
The base class automates the MC-sampling approximation of the Hessian square
root. It requires specifying the likelihood distribution; the Hessian square
root is then approximated by computing gradients with `autograd` on targets
drawn from that likelihood.
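
Below is a minimal, library-independent sketch of that approximation for the Gaussian/MSE case. The function name and the unit-variance likelihood are illustrative only, not BackPACK's API, and PyTorch's `MSELoss` differs from this Gaussian NLL by a constant scaling factor.

```python
import torch
from torch.distributions import Normal


def sampled_sqrt_hessian(subsampled_input: torch.Tensor, mc_samples: int) -> torch.Tensor:
    """Gradient samples whose outer product approximates the loss Hessian w.r.t. the output.

    Illustrative sketch: uses a unit-variance Gaussian likelihood (the MSE case up to scaling).
    """
    subsampled_input = subsampled_input.detach().requires_grad_(True)
    likelihood = Normal(subsampled_input, 1.0)
    grads = []
    for _ in range(mc_samples):
        target = likelihood.sample()  # draw a target from the likelihood
        nll = -likelihood.log_prob(target).sum()  # negative log-likelihood
        (g,) = torch.autograd.grad(nll, subsampled_input)  # gradient via autograd
        grads.append(g)
    # stacking the scaled gradient samples V gives V^T V ≈ Hessian (Fisher identity)
    return torch.stack(grads) / mc_samples**0.5
```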

---


* Updated cross entropy and MSE to use new NLL base

* Changed _post_process and _checks to raise NotImplementedError

* Refactored NLL base to use the general log-prob derivative from `torch.distributions`, with overrides for MSE and CE

* Fixed some spacing errors

* Updated doc strings

* Added test for NLL version of compute_sampled_grads

* Fixed spacing issue introduced in last commit

* Implemented (some of) commented fixes

* Added NotImplementedError to _verify_support to fix coveralls problem

* Fixed some commenting, removed mean from mse make distribution, moved sampling into loop for nll base

* [REF] Use distribution for sampled gradients in manual approach

* Changed use_dist to use_autograd, fixed return statement error in MSE, fixed device error in MSE

* Black

* Reverted changes to CEL

* Reverted cross entropy changes

* Removed unneeded changes to clean the diff

* Some docstring updates

* A few missed changes for diff

* [REF] Move `use_autograd` inside `NLLLossDerivatives`

* [REF] Change default of `use_autograd` to `False` for `MSELoss`

* [FMT] Remove space

* [DEL] Remove `use_autograd` from `CrossEntropyLoss`

* [DOC] Clarify `use_autograd` in test function

* [FIX] Syntax error

* [CI] Add NLLLossDerivatives to fully documented

* Added missing type annotations to nll_base.py, removed redundant autograd_res call in test_sqrt_hessian_sampled_squared_approximates_hessian_nll

* Removed unnecessary ABC in MSE loss, fixed _compute_sampled_grad_manual name

* Fixed documentation to match standards and reflect current version

* Darglint fixes

* Pydocstyle fix

* Darglint formatting fix

* Removed retain_graph=True

* Created MSE_LOSS_PROBLEMS for test_sqrt_hessian_sampled_squared_approximates_hessian_nll to run on

* Added autograd test to check that sample has same shape as subsampled_input

* Reformatted some too-long lines

* [REF] Remove `enable_grad` and `Variable`

* [REF] Shorten import

* [REF] Rewrite NLL test with recursion

* [FIX] Remove unused import

* [FIX] darglint

* [DOC] Polish MSELoss

* [DOC] Polish NLL base

* [DOC] Polish derivatives test

* [FIX] Type annotation

* [FIX] Darglint

* [DOC] Polish NLLbase

* [DOC] One more pass through docstrings

Co-authored-by: Felix Dangel <[email protected]>
Co-authored-by: Felix Dangel <[email protected]>
- Replace shape check of samples in main library with test
- Add `retain_graph=True` for autograd computation of sampled gradients
  (for MSELoss, it worked without `retain_graph`)

---

* [REF] Changed cross entropy loss to NLL base

* [REF] Removed arrange and rearrange, made CE work for autograd

* [REF] Changed compute_grad_manual for CE to use _make_distribution

* [REF] some cleaning

* [REF] Moved nll distribution shape check

* [FIX] darglint, isort

* [FIX] removed some unused import statements

* [REF] Remove redundant import, improve names

* [REF] Improve readability by linebreaks

* [REF] Import loss modules from `torch.nn`

* [TEST] Apply sub-sampling to input and target for shape check

* [FIX] Add tear_down call

* [DEL] Remove clone+detach

Co-authored-by: Felix Dangel <[email protected]>
* [ADD] Added BCEWithLogits loss to NLL base

* [TEST] Skip BCEWithLogitsLoss _sqrt_hessian and
_compute_sampled_grads_manual (not implemented)

* [DOC] Fix darglint

* [DEL] Remove f-string

* [TEST] Skip unimplemented methods for BCEWithLogitsLoss

* [REF] Raise NotImplementedErrors, rename output -> target

* [REF] Rename bceloss -> bcewithlogitsloss

* [REF] Fewer imports, type annotation, docstring polish

* [ADD] Support `reduction='sum'`

* [DEL] Remove redundant constructor

Co-authored-by: Felix Dangel <[email protected]>
* [REF] Declare `abstractmethod`s

* [REF] Ignore warning: class inheriting from `ABC` has no abstract methods

Co-authored-by: Felix Dangel <[email protected]>
Replaces
`from backpack import ..., X` <-> `from backpack import ..., extensions`
and
`with backpack(X())` <-> `with backpack(extensions.X())`

Also applies whitespace cleanup.

Co-authored-by: Felix Dangel <[email protected]>
* [FIX] Use batch size 1

* [DEL] Remove unused import

Co-authored-by: Felix Dangel <[email protected]>
* [CI] Test with `torch=={1.9.0, 1.10.0}`

* [CI] Test with `torch=={1.9.0, 1.11.0}`

* [FIX] flake8

* [CI] Test with `torch=={1.9.0, 1.12.0}`

* [TEST] Replace `parameters_to_vector` by custom function

This should fix
`test_network_diag_ggn[<class
'test.converter.converter_cases._Permute'>]`
in `test/converter/test_converter.py`. Between torch 1.11.0 and torch
1.12.0, the GGN-vector products for this case became non-contiguous, and
`torch.nn.utils.convert_parameters.parameters_to_vector` stopped working
as it uses `view`.

Here is a short self-contained snippet to reproduce the issue:

```python
from torch import Tensor, permute, rand, rand_like
from torch.autograd import grad
from torch.nn import Linear, Module
from torch.nn.utils.convert_parameters import parameters_to_vector

from backpack.utils.convert_parameters import tensor_list_to_vector

class Permute(Module):
    def __init__(self):
        super().__init__()
        self.batch_size = 3
        self.in_dim = (5, 3)
        out_dim = 2
        self.linear = Linear(self.in_dim[-1], out_dim)
        self.linear2 = Linear(self.in_dim[-2], out_dim)

    def forward(self, x):
        x = self.linear(x)
        x = x.permute(0, 2, 1)  # method permute
        x = self.linear2(x)
        x = permute(x, (0, 2, 1))  # function permute
        return x

    def input_fn(self) -> Tensor:
        return rand(self.batch_size, *self.in_dim)

model = Permute()

inputs = model.input_fn()
outputs = model(inputs)

params = list(model.parameters())
grad_outputs = rand_like(outputs)
v = [rand_like(p) for p in model.parameters()]

vJ_tuple = grad(outputs, params, grad_outputs=grad_outputs)

for p, vJ in zip(params, vJ_tuple):
    # all contiguous()
    print(p.shape, vJ.shape)
    # between 1.11.0 and 1.12.0, the vector-Jacobian product w.r.t. the second
    # linear layer's weight is not contiguous anymore
    print(p.is_contiguous(), vJ.is_contiguous())

vJ_vector = parameters_to_vector(vJ_tuple)

vJ_vector = tensor_list_to_vector(vJ_tuple)
```
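
A `view`-free flattening helper along these lines avoids the failure. This is only a sketch of the idea with a hypothetical name, not necessarily BackPACK's exact implementation of `tensor_list_to_vector`:

```python
from typing import List

import torch
from torch import Tensor


def flatten_tensor_list(tensor_list: List[Tensor]) -> Tensor:
    """Concatenate tensors into a vector.

    `flatten` copies non-contiguous tensors instead of failing like `view`.
    """
    return torch.cat([t.flatten() for t in tensor_list])
```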

* [REF] Use f-string and add type hints

* [REQ] Require `torch<1.13`

See #272. Waiting for
pytorch/pytorch#88312 before `torch>=1.13`
can be supported.

* [DOC] Update changelog to prepare compatibility patch

* [DOC] fix date

Co-authored-by: Felix Dangel <[email protected]>
* [ADD] Implement sqrt_hessian for BCEWithLogitsLoss

* [TEST] Add case for sqrt_hessian with non-binary labels

* [DOC] Improve derivation of Hessian square root

Co-authored-by: Felix Dangel <[email protected]>
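For reference, the derivation boils down to the fact that the Hessian of `BCEWithLogitsLoss` w.r.t. the logits is diagonal with entries p(1 - p), where p = sigmoid(z), independent of the targets; its elementwise square root is therefore an exact symmetric Hessian square root. A small autograd check of this fact (illustrative only, not BackPACK code):

```python
import torch
from torch.autograd.functional import hessian
from torch.nn import BCEWithLogitsLoss

logits = torch.rand(4)
targets = torch.randint(0, 2, (4,)).float()
loss_func = BCEWithLogitsLoss(reduction="sum")

# Hessian of the loss w.r.t. the logits, computed with autograd
H = hessian(lambda z: loss_func(z, targets), logits)

# closed form: diagonal with entries p * (1 - p), independent of the targets
p = torch.sigmoid(logits)
H_closed_form = torch.diag(p * (1 - p))
print(torch.allclose(H, H_closed_form, atol=1e-6))  # True

# hence diag(sqrt(p * (1 - p))) is a symmetric Hessian square root
sqrt_H = torch.diag((p * (1 - p)).sqrt())
print(torch.allclose(sqrt_H @ sqrt_H, H, atol=1e-6))  # True
```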
… by default (#278)

* [ADD] Implement sqrt_hessian for BCEWithLogitsLoss

* [TEST] Add case for sqrt_hessian with non-binary labels

* [DOC] Improve derivation of Hessian square root

* [ADD] Implement manual sampled gradients for `BCEWithLogitsLoss`

* [CI] Add to fully documented files

* [DOC] Fix pydocstyle

Co-authored-by: Felix Dangel <[email protected]>
* [ADD] Implement sqrt_hessian for BCEWithLogitsLoss

* [TEST] Add case for sqrt_hessian with non-binary labels

* [DOC] Improve derivation of Hessian square root

* [ADD] Implement manual sampled gradients for `BCEWithLogitsLoss`

* [CI] Add to fully documented files

* [DOC] Fix pydocstyle

* [ADD] `DiagHessian` support for `BCEWithLogitsLoss`

* [CI] Skip `BCEWithLogitsLoss` cases in DiagGGNExactBatch

* [CI] Skip `BCEWithLogitsLoss` for `DiagGGNExact`

* [CI] Skip `BCEWithLogitsLoss` in `SqrtGGN` tests

* [REF] Use `BCEWithLogitsLoss` test cases for Hessian diagonal

* [DEL] Remove skip utilities for `BCEWithLogitsLoss`

Co-authored-by: Felix Dangel <[email protected]>
* [ADD] Implement sqrt_hessian for BCEWithLogitsLoss

* [TEST] Add case for sqrt_hessian with non-binary labels

* [DOC] Improve derivation of Hessian square root

* [ADD] Implement manual sampled gradients for `BCEWithLogitsLoss`

* [CI] Add to fully documented files

* [DOC] Fix pydocstyle

* [ADD] `DiagHessian` support for `BCEWithLogitsLoss`

* [CI] Skip `BCEWithLogitsLoss` cases in DiagGGNExactBatch

* [CI] Skip `BCEWithLogitsLoss` for `DiagGGNExact`

* [CI] Skip `BCEWithLogitsLoss` in `SqrtGGN` tests

* [REF] Use `BCEWithLogitsLoss` test cases for Hessian diagonal

* [DEL] Remove skip utilities for `BCEWithLogitsLoss`

* [ADD] Support `BCEWithLogitsLoss` in `(Diag)GGN{Exact,MC}`

Co-authored-by: Felix Dangel <[email protected]>
* [ADD] Implement sqrt_hessian for BCEWithLogitsLoss

* [TEST] Add case for sqrt_hessian with non-binary labels

* [DOC] Improve derivation of Hessian square root

* [ADD] Implement manual sampled gradients for `BCEWithLogitsLoss`

* [CI] Add to fully documented files

* [DOC] Fix pydocstyle

* [ADD] `DiagHessian` support for `BCEWithLogitsLoss`

* [CI] Skip `BCEWithLogitsLoss` cases in DiagGGNExactBatch

* [CI] Skip `BCEWithLogitsLoss` for `DiagGGNExact`

* [CI] Skip `BCEWithLogitsLoss` in `SqrtGGN` tests

* [REF] Use `BCEWithLogitsLoss` test cases for Hessian diagonal

* [DEL] Remove skip utilities for `BCEWithLogitsLoss`

* [ADD] Support `BCEWithLogitsLoss` in `(Diag)GGN{Exact,MC}`

* [ADD] Support `BCEWithLogitsLoss` in `SqrtGGN{Exact, MC}` extension

* [DEL] Forgot to extract BCEWithLogitsLoss test cases

Co-authored-by: Felix Dangel <[email protected]>
* [ADD] Implement sqrt_hessian for BCEWithLogitsLoss

* [TEST] Add case for sqrt_hessian with non-binary labels

* [DOC] Improve derivation of Hessian square root

* [ADD] Implement manual sampled gradients for `BCEWithLogitsLoss`

* [CI] Add to fully documented files

* [DOC] Fix pydocstyle

* [ADD] `DiagHessian` support for `BCEWithLogitsLoss`

* [CI] Skip `BCEWithLogitsLoss` cases in DiagGGNExactBatch

* [CI] Skip `BCEWithLogitsLoss` for `DiagGGNExact`

* [CI] Skip `BCEWithLogitsLoss` in `SqrtGGN` tests

* [REF] Use `BCEWithLogitsLoss` test cases for Hessian diagonal

* [DEL] Remove skip utilities for `BCEWithLogitsLoss`

* [ADD] Support `BCEWithLogitsLoss` in `(Diag)GGN{Exact,MC}`

* [ADD] Support `BCEWithLogitsLoss` in `SqrtGGN{Exact, MC}` extension

* [DEL] Forgot to extract BCEWithLogitsLoss test cases

* [ADD] Support `BCEWithLogitsLoss` in `KFAC`

Co-authored-by: Felix Dangel <[email protected]>
* [ADD] Implement sqrt_hessian for BCEWithLogitsLoss

* [TEST] Add case for sqrt_hessian with non-binary labels

* [DOC] Improve derivation of Hessian square root

* [ADD] Implement manual sampled gradients for `BCEWithLogitsLoss`

* [CI] Add to fully documented files

* [DOC] Fix pydocstyle

* [ADD] `DiagHessian` support for `BCEWithLogitsLoss`

* [CI] Skip `BCEWithLogitsLoss` cases in DiagGGNExactBatch

* [CI] Skip `BCEWithLogitsLoss` for `DiagGGNExact`

* [CI] Skip `BCEWithLogitsLoss` in `SqrtGGN` tests

* [REF] Use `BCEWithLogitsLoss` test cases for Hessian diagonal

* [DEL] Remove skip utilities for `BCEWithLogitsLoss`

* [ADD] Support `BCEWithLogitsLoss` in `(Diag)GGN{Exact,MC}`

* [ADD] Support `BCEWithLogitsLoss` in `SqrtGGN{Exact, MC}` extension

* [DEL] Forgot to extract BCEWithLogitsLoss test cases

* [ADD] Support `BCEWithLogitsLoss` in `KFAC`

* [TEST] Add KFLR test, split KFAC settings to recycle in KFLR tests

Co-authored-by: Felix Dangel <[email protected]>
* [REQ] Use `unfoldNd` package to unfold convolution inputs

The removed code has been extracted into a separate package.

* [DOC] Describe argument and return shape

Co-authored-by: Felix Dangel <[email protected]>
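For context, "unfolding" (im2col) turns convolution into a matrix multiplication. The snippet below shows the 2d case with `torch.nn.functional.unfold`, which the `unfoldNd` package generalizes to 1d/3d and to transpose convolution; the package's own API is not shown here.

```python
import torch
from torch.nn import functional as F

N, C_in, H, W = 2, 3, 8, 8
C_out, K = 4, 3

inputs = torch.rand(N, C_in, H, W)
weight = torch.rand(C_out, C_in, K, K)

# unfold the input into patch columns, then convolution is a matmul
unfolded = F.unfold(inputs, kernel_size=K)        # (N, C_in*K*K, L)
as_matmul = weight.reshape(C_out, -1) @ unfolded  # (N, C_out, L)
H_out, W_out = H - K + 1, W - K + 1
as_matmul = as_matmul.reshape(N, C_out, H_out, W_out)

conv = F.conv2d(inputs, weight)
print(torch.allclose(conv, as_matmul, atol=1e-6))  # True
```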
* [REQ] Use `unfoldNd` package to unfold convolution inputs

The removed code has been extracted into a separate package.

* [DOC] Describe argument and return shape

* [REF] Fully document convolution utilities

Co-authored-by: Felix Dangel <[email protected]>
…tions (#287)

* [REQ] Use `unfoldNd` package to unfold convolution inputs

The removed code has been extracted into a separate package.

* [DOC] Describe argument and return shape

* [REF] Fully document convolution utilities

* [ADD] Use `unfoldNd` to unfold input of transpose convolution

* [DEL] Remove old code for unfolding

Co-authored-by: Felix Dangel <[email protected]>
* [REQ] Use `unfoldNd` package to unfold convolution inputs

The removed code has been extracted into a separate package.

* [DOC] Describe argument and return shape

* [REF] Fully document convolution utilities

* [ADD] Use `unfoldNd` to unfold input of transpose convolution

* [DEL] Remove old code for unfolding

* [DOC] Fully-document transpose convolution utilities

Co-authored-by: Felix Dangel <[email protected]>
… convolution case (#289)

* [REQ] Use `unfoldNd` package to unfold convolution inputs

The removed code has been extracted into a separate package.

* [DOC] Describe argument and return shape

* [REF] Fully document convolution utilities

* [ADD] Use `unfoldNd` to unfold input of transpose convolution

* [DEL] Remove old code for unfolding

* [DOC] Fully-document transpose convolution utilities

* [REF] Make output shape of unfold_by_conv_transpose consistent

Co-authored-by: Felix Dangel <[email protected]>
Add documentation to the `HBP` extension of `Conv2d` layers, which is
responsible for computing `KFAC/KFLR/KFRA`. The docstrings draw connections
to the notation in the [KFC paper](https://arxiv.org/pdf/1602.01407.pdf),
outline important differences, and suggest improvements for consistency.
Also add a test case for `KFAC, KFLR` in which both approximations become exact.

Note to myself: I made notes on how to connect Hessian backpropagation to `KFAC`
for convolutions by imposing a Kronecker structure on the backpropagated quantity.
The same idea could also be applied to `KFRA` for more consistency, but the code
does not currently do this.
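
As a rough illustration of the structure those docstrings describe (KFC-style Kronecker factors for a `Conv2d` weight, following Grosse & Martens, 2016): scaling and ordering conventions (mean vs. sum over batch and spatial positions, `kron` argument order) differ between references and implementations, so this is a sketch, not BackPACK's exact computation.

```python
import torch
from torch.nn import functional as F

N, C_in, H, W, C_out, K = 2, 3, 8, 8, 4, 3
inputs = torch.rand(N, C_in, H, W)
grad_output = torch.rand(N, C_out, H - K + 1, W - K + 1)  # backpropagated derivatives

# first factor: second moment of the unfolded input patches (im2col columns)
patches = F.unfold(inputs, kernel_size=K)               # (N, C_in*K*K, L)
A = torch.einsum("nil,njl->ij", patches, patches) / N   # (C_in*K*K, C_in*K*K)

# second factor: second moment of the output derivatives over spatial positions
g = grad_output.flatten(2)                               # (N, C_out, L)
B = torch.einsum("nil,njl->ij", g, g) / N                # (C_out, C_out)

# the curvature block for the flattened weight is approximated by the
# Kronecker product of the two factors (up to ordering and scaling conventions)
```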

* [DOC] Fully document `HBPConv2d`

* [TEST] KFAC/KFLR for convolution with single output

Convolution layers with a single output behave like linear layers, as
the weights are not shared over the input.

* [DOC] Polish docstrings

* [TEST] Add integration test for KFRA
Generalize the Kronecker-factored approximations of Hessian diagonal blocks
(`KFRA`, `KFLR`, `KFAC`) from `Conv2d` to `Conv1d` and `Conv3d`. Add a test for the
`KFRA` approximation under specific limits.

* [DOC] Fully document `HBPConv2d`

* [TEST] KFAC/KFLR for convolution with single output

Convolution layers with a single output behave like linear layers, as
the weights are not shared over the input.

* [DOC] Polish docstrings

* [TEST] Add integration test for KFRA

* [ADD] Kronecker approximations for `ConvNd` (`N=1,2,3`)

* [DOC] Fix some typos in the docstrings

* [DOC] Simplify description of returned Kronecker proxies

* [TEST] Replace KFRA property check by value check

* [CI] Add `ConvNd` files to fully-documented

* [FIX] Typo in file name

* [FIX] Call KFRA, not KFLR

Generalize Kronecker approximations for convolution to transpose convolution.

* [DOC] Fully document `HBPConv2d`

* [TEST] KFAC/KFLR for convolution with single output

Convolution layers with a single output behave like linear layers, as
the weights are not shared over the input.

* [DOC] Polish docstrings

* [TEST] Add integration test for KFRA

* [ADD] Kronecker approximations for `ConvNd` (`N=1,2,3`)

* [DOC] Fix some typos in the docstrings

* [DOC] Simplify description of returned Kronecker proxies

* [TEST] Replace KFRA property check by value check

* [CI] Add `ConvNd` files to fully-documented

* [FIX] Typo in file name

* [FIX] Call KFRA, not KFLR

* [ADD] Kronecker approximations for `ConvTranspose{1,2,3}d`

- Adds a test case for `ConvTranspose2d`. Note that due to the
  different index order, working with the Kronecker representation
  for weights of transpose convolutions is involved; warn the user
  about this.
- Adapts tests of `KFAC, KFLR, KFRA` by adding a utility function to
  fix the index order after expanding the Kronecker product.

* [TEST] Add cases for `ConvTranspose{1,3}d`

* [FIX] pydocstyle
* [FIX] Copy `_grad_input_padding` from torch==1.9

The function was removed between torch 1.12.1 and torch 1.13.
Reintroducing it should fix
#272.

* [CI] Use latest two torch releases for tests

* [FIX] Ignore flake8 warning about abstract methods

* [FIX] Import

* [CI] Test with `torch=={1.9.0, 1.12.0}` and make tests compatible (#276)

* [CI] Test with `torch=={1.9.0, 1.10.0}`

* [CI] Test with `torch=={1.9.0, 1.11.0}`

* [FIX] flake8

* [CI] Test with `torch=={1.9.0, 1.12.0}`

* [TEST] Replace `parameters_to_vector` by custom function

This should fix
`test_network_diag_ggn[<class
'test.converter.converter_cases._Permute'>]`
in `test/converter/test_converter.py`. Between torch 1.11.0 and torch
1.12.0, the GGN-vector products for this case became non-contiguous, and
`torch.nn.utils.convert_parameters.parameters_to_vector` stopped working
as it uses `view`.

Here is a short self-contained snippet to reproduce the issue:

```python
from torch import Tensor, permute, rand, rand_like
from torch.autograd import grad
from torch.nn import Linear, Module
from torch.nn.utils.convert_parameters import parameters_to_vector

from backpack.utils.convert_parameters import tensor_list_to_vector

class Permute(Module):
    def __init__(self):
        super().__init__()
        self.batch_size = 3
        self.in_dim = (5, 3)
        out_dim = 2
        self.linear = Linear(self.in_dim[-1], out_dim)
        self.linear2 = Linear(self.in_dim[-2], out_dim)

    def forward(self, x):
        x = self.linear(x)
        x = x.permute(0, 2, 1)  # method permute
        x = self.linear2(x)
        x = permute(x, (0, 2, 1))  # function permute
        return x

    def input_fn(self) -> Tensor:
        return rand(self.batch_size, *self.in_dim)

model = Permute()

inputs = model.input_fn()
outputs = model(inputs)

params = list(model.parameters())
grad_outputs = rand_like(outputs)
v = [rand_like(p) for p in model.parameters()]

vJ_tuple = grad(outputs, params, grad_outputs=grad_outputs)

for p, vJ in zip(params, vJ_tuple):
    # all contiguous()
    print(p.shape, vJ.shape)
    # between 1.11.0 and 1.12.0, the vector-Jacobian product w.r.t. the second
    # linear layer's weight is not contiguous anymore
    print(p.is_contiguous(), vJ.is_contiguous())

vJ_vector = parameters_to_vector(vJ_tuple)

vJ_vector = tensor_list_to_vector(vJ_tuple)
```

* [REF] Use f-string and add type hints

* [REQ] Require `torch<1.13`

See #272. Waiting for
pytorch/pytorch#88312 before `torch>=1.13`
can be supported.

* [DOC] Update changelog to prepare compatibility patch

* [DOC] fix date

Co-authored-by: Felix Dangel <[email protected]>

* [CI] Test torch from 1.9 to 1.13

* [FIX] Ignore 'zip()' without an explicit 'strict=' parameter

* [REF] Make GGNvps contiguous before flattening and concatenation

* [CI] Unambiguously specify tested torch versions

* [REF] Import _grad_input_padding from torch for torch<1.13

* [FIX] Exception handling for Hessians of linear functions

* [REF] Same `_grad_input_padding` import strategy for conv_transpose

* [FIX] Merge conflict

* [CI] Ignore docstring check of _grad_input_padding

* [DOC] Add type annotation, remove unused import

* [DOC] Add type annotation for output
Adds a tutorial for BackPACK's `retain_graph` option. It shows
how to distribute the GGN diagonal computation of an auto-encoder
architecture over multiple backward passes to reduce peak memory.

This use case recently came up in a discussion with @wiseodd
on Laplace approximations for auto-encoders (or any neural network
with a large output and square loss).
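
The tutorial itself uses BackPACK's `DiagGGNExact` together with the new `retain_graph` option. As a library-independent illustration of why retaining the graph helps, here is a minimal autograd sketch that accumulates the exact GGN diagonal of a square loss one output entry at a time instead of in a single large pass:

```python
from torch import rand, zeros_like
from torch.autograd import grad
from torch.nn import Linear, Sequential

model = Sequential(Linear(10, 5), Linear(5, 10))  # toy auto-encoder
params = list(model.parameters())

X = rand(3, 10)
output = model(X)  # square loss: ((output - X) ** 2).sum()

ggn_diag = [zeros_like(p) for p in params]

# One small backward pass per output entry; retain_graph=True keeps the
# graph alive so the forward pass is not repeated. For the square loss,
# the GGN diagonal is 2 * the sum of squared output-Jacobian entries.
for n in range(output.shape[0]):
    for i in range(output.shape[1]):
        grads = grad(output[n, i], params, retain_graph=True)
        for d, g in zip(ggn_diag, grads):
            d.add_(2 * g**2)
```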

* [ADD] Prototype of `retain_graph` example

* [DOC] Add comments to retain_graph example

* [REF] Improve comments

* [REF] Improve title format
* [REQ] Remove upper version restrictions for `torch` and `torchvision`

* [REQ] Bump python to 3.8+

* [REF] Replace `Tensor.symeig` with `torch.linalg.eigh` (see the migration sketch after this commit list)

* [CI] Replace `python3.7` with `python3.8`

* [REF] Try fixing syntax for `flake8` in `setup.cfg`

* [TEST] Skip double-backward of LSTM for PyTorch 2.0.1

See pytorch/pytorch#99413

* [FIX] flake8

* [TEST] Skip `jac_mat_prod` for LSTM in PyTorch 2.0.1

double-backward is not supported, see pytorch/pytorch#99413

* [CI] Use python3.8 in RTD build

* [CI] Skip LSTM for PyTorch 2.0.1 in DiagGGN tests

* [FIX] Imports

* [FIX] Turn off MKLDNN in RNN example
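
The `symeig` replacement above follows the standard PyTorch deprecation path; both routines return eigenvalues in ascending order for a symmetric input:

```python
import torch

mat = torch.rand(5, 5)
mat = mat + mat.T  # symmetric matrix

# removed in recent PyTorch:
# evals, evecs = mat.symeig(eigenvectors=True)

# replacement used here:
evals, evecs = torch.linalg.eigh(mat)
```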

---------

Co-authored-by: Felix Dangel <[email protected]>
Co-authored-by: Felix Dangel <[email protected]>
f-dangel merged commit 1ebfb40 into master on Jul 12, 2023
42 checks passed