Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of VecCorrBijector #246

Merged
merged 179 commits into from
Jun 12, 2023
Merged
Show file tree
Hide file tree
Changes from 175 commits
Commits
Show all changes
179 commits
Select commit Hold shift + click to select a range
79b92c9
initial work on VecCorrBijector
torfjelde Feb 6, 2023
aa2fe61
added some tests for CorrBijector, and fixed implementation for VecCo…
torfjelde Feb 6, 2023
8d23094
improved tests and are now using integer sqrt and division
torfjelde Feb 6, 2023
a35e36f
moved things around a bit
torfjelde Feb 12, 2023
8cadf69
added chainrule for ReverseDiff
torfjelde Feb 13, 2023
eaf5324
some fixes for AD
torfjelde Feb 13, 2023
36ffbdb
added some TODOs
torfjelde Feb 13, 2023
62ae1ac
Update src/bijectors/corr.jl
torfjelde Mar 24, 2023
3f25a8b
define bijectors for `LKJ` and `LKJCholesky`
harisorgn Apr 4, 2023
e1567c3
add `TransformedDistribution` constructor
harisorgn Apr 6, 2023
8d07e34
define `logpdf` for `LKJ` & `LKJCholesky`
harisorgn Apr 6, 2023
9a59a9f
define `rand` for `LKJ` & `LKJCholesky`
harisorgn Apr 6, 2023
f15ad85
add util to extract Cholesky factor
harisorgn Apr 6, 2023
53e78f3
TYPO: capitalize matrix
harisorgn Apr 6, 2023
ec7d20e
add util to convert `Vector` index
harisorgn Apr 6, 2023
2ed00f4
add `VecTriBijector`s for `LKJCholesky`
harisorgn Apr 6, 2023
07555fc
TYPO: capitilize matrix
harisorgn Apr 6, 2023
a75cabc
add `LKJCholesky` link for `UpperTriangular`
harisorgn Apr 6, 2023
844b07e
add `LKJCholesky` link for `LowerTriangular`
harisorgn Apr 6, 2023
792cfe9
TYPO: capitalize matrix
harisorgn Apr 6, 2023
8f0886b
add `LKJCholesky` inverse link to `UpperTriangular`
harisorgn Apr 6, 2023
35f1c03
rename `_logabsdetjac_chol_lkj`
harisorgn Apr 6, 2023
9d55829
dispatch `_logabsdetjac_inv_corr` for `::Vector`
harisorgn Apr 6, 2023
adf10ad
add logabsdetjac for inverse link of `LKJCholesky`
harisorgn Apr 6, 2023
03a55b2
add tests for `VecTriBijector`s
harisorgn Apr 6, 2023
1059569
add `rrule` for LKJ(Cholesky) link function
harisorgn Apr 6, 2023
222eb6e
Merge branch 'torfjelde/vec-corr' into ho/vec-lkj-cholesky
harisorgn Apr 6, 2023
7f5d0fc
Merge pull request #1 from harisorgn/ho/vec-lkj-cholesky
harisorgn Apr 6, 2023
ad080ea
use `transpose` in link for `::LowerTriangular'
harisorgn Apr 11, 2023
6e1a5b1
add `Tracker` support for inverse link
harisorgn Apr 12, 2023
5fd0a65
better utility function call
harisorgn Apr 12, 2023
b38acda
use function barrier properly for type stability
harisorgn Apr 12, 2023
424f8ca
account for difference in support dimensions
harisorgn Apr 13, 2023
b749d37
fix indexing in Jacobian of `VecCorrBijector`
harisorgn Apr 13, 2023
7b1f74d
add `_logabsdetjac_dist` for `::LKJCholesky`
harisorgn Apr 13, 2023
75c605b
replace function composition for proper barrier
harisorgn Apr 13, 2023
a7a6c05
add util convert `Transpose -> Matrix` for type stability
harisorgn Apr 13, 2023
09c35b6
add `LKJCholesky` Jacobian+type tests
harisorgn Apr 13, 2023
2ad5038
fix `logabsdetjac` for inverse link
harisorgn Apr 14, 2023
f5be4e2
use `Cholesky` constructor compatible with `v1.6`
harisorgn Apr 14, 2023
10d9345
add empty line
harisorgn Apr 17, 2023
bcf32a3
fix `rrule` for link function
harisorgn Apr 17, 2023
7f4551f
add link `rrule` test
harisorgn Apr 17, 2023
dc2c856
add `rrule` for inverse link
harisorgn Apr 17, 2023
87bc3ca
remove TODO
harisorgn Apr 17, 2023
bfb7c15
add inverse link `rrule` test
harisorgn Apr 17, 2023
20ab3b4
Update src/bijectors/corr.jl
harisorgn Apr 17, 2023
7bb37e0
add link `rrule` for `LowerTriangular`
harisorgn Apr 18, 2023
3e2c7a8
add `LowerTriangular` chainrule test
harisorgn Apr 18, 2023
adba9e8
Update src/bijectors/corr.jl
harisorgn Apr 18, 2023
ec18964
remove unused util
harisorgn Apr 18, 2023
37c38ab
use `similar` instead of `zeros`
harisorgn Apr 18, 2023
8fd13b0
update comments
harisorgn Apr 18, 2023
56cc43f
remove old comment
harisorgn Apr 18, 2023
8ee086a
minimize zero-setting operations in inverse link
harisorgn Apr 18, 2023
837b49c
minimize zero-setting operations in `rrule`
harisorgn Apr 18, 2023
0c3aa39
add parametric `Val` type to `VecCorrBijector`
harisorgn Apr 18, 2023
c1be272
update `VecCorrBijector` tests
harisorgn Apr 18, 2023
29fced6
use field value instead of `Val`-parametric type
harisorgn Apr 18, 2023
74d6edb
update tests with new `VecCorrBijector`
harisorgn Apr 18, 2023
4c27987
`using VecCorrBijector` in test utils
harisorgn Apr 18, 2023
9108c40
add `VecCorrBijector.mode` check
harisorgn Apr 18, 2023
24847cc
update `VecCorrBijector` docstring
harisorgn Apr 18, 2023
bd4de96
specialise `Zygote@adjoint` for `AbstractMatrix`
harisorgn Apr 18, 2023
65bfc42
`ReverseDiff` opt-in to `ChainRules`
harisorgn Apr 18, 2023
eca3411
empty lines format
harisorgn Apr 18, 2023
f02fd9b
add AD test for inverse link
harisorgn Apr 18, 2023
c90f7ac
include `VecCorrBijector` tests
harisorgn Apr 18, 2023
974efb5
remove broken flag for `Tracker`
harisorgn Apr 18, 2023
71fdae6
add roundtrip AD tests for `VecCorrBijector`
harisorgn Apr 18, 2023
6524fe4
remove wrong `ReverseDiff.@grad` for `pd_from_upper`
harisorgn Apr 19, 2023
5e4abae
add corrected `rrule` for `pd_from_upper`
harisorgn Apr 19, 2023
c547542
update AD tests
harisorgn Apr 19, 2023
0d599e8
remove `Tracker` from broken
harisorgn Apr 19, 2023
a1f16b6
update zero-filling in `Tracker` pullback
harisorgn Apr 25, 2023
8b4b0c7
fix `Zygote`
harisorgn Apr 25, 2023
890127f
merge lines - applying feedback suggestions
harisorgn May 4, 2023
fa13e27
`unthunk` in `pd_from_upper` rrule
harisorgn May 24, 2023
a36f2b6
split structs into `VecCorrBijector` and `VecCholeskyBijector`
harisorgn May 24, 2023
9690dd2
remove old `Zygote` adjoints
harisorgn May 24, 2023
8a67713
update tests
harisorgn May 24, 2023
37cfd90
fix `Union` in `@inferred` after splitting structs
harisorgn May 24, 2023
a3c7f57
remove `Tracker` tests as support is dropped
harisorgn May 24, 2023
df4d960
use `permutedims` instead of casting
harisorgn Jun 6, 2023
17f784f
remove `Union` in `@inferred`
harisorgn Jun 6, 2023
852573d
initial work on VecCorrBijector
torfjelde Feb 6, 2023
cea5f19
added some tests for CorrBijector, and fixed implementation for VecCo…
torfjelde Feb 6, 2023
89612cc
improved tests and are now using integer sqrt and division
torfjelde Feb 6, 2023
bc8f755
moved things around a bit
torfjelde Feb 12, 2023
9b3d7e9
added chainrule for ReverseDiff
torfjelde Feb 13, 2023
b1176d0
some fixes for AD
torfjelde Feb 13, 2023
f3a623f
added some TODOs
torfjelde Feb 13, 2023
d46e966
define bijectors for `LKJ` and `LKJCholesky`
harisorgn Apr 4, 2023
f210356
add `TransformedDistribution` constructor
harisorgn Apr 6, 2023
71e1017
define `logpdf` for `LKJ` & `LKJCholesky`
harisorgn Apr 6, 2023
37e649c
define `rand` for `LKJ` & `LKJCholesky`
harisorgn Apr 6, 2023
c09c5c8
add util to extract Cholesky factor
harisorgn Apr 6, 2023
2a514c8
TYPO: capitalize matrix
harisorgn Apr 6, 2023
6596c9e
add util to convert `Vector` index
harisorgn Apr 6, 2023
6123d6d
add `VecTriBijector`s for `LKJCholesky`
harisorgn Apr 6, 2023
791f764
TYPO: capitilize matrix
harisorgn Apr 6, 2023
f47cdac
add `LKJCholesky` link for `UpperTriangular`
harisorgn Apr 6, 2023
959b836
add `LKJCholesky` link for `LowerTriangular`
harisorgn Apr 6, 2023
a8ccaa1
TYPO: capitalize matrix
harisorgn Apr 6, 2023
82bf085
add `LKJCholesky` inverse link to `UpperTriangular`
harisorgn Apr 6, 2023
597b6a1
rename `_logabsdetjac_chol_lkj`
harisorgn Apr 6, 2023
54dd86d
dispatch `_logabsdetjac_inv_corr` for `::Vector`
harisorgn Apr 6, 2023
eaf60f7
add logabsdetjac for inverse link of `LKJCholesky`
harisorgn Apr 6, 2023
861eef6
add tests for `VecTriBijector`s
harisorgn Apr 6, 2023
78b9999
add `rrule` for LKJ(Cholesky) link function
harisorgn Apr 6, 2023
5b4119a
use `transpose` in link for `::LowerTriangular'
harisorgn Apr 11, 2023
011534c
add `Tracker` support for inverse link
harisorgn Apr 12, 2023
ff61ef0
better utility function call
harisorgn Apr 12, 2023
a2ec603
use function barrier properly for type stability
harisorgn Apr 12, 2023
4c3a68b
account for difference in support dimensions
harisorgn Apr 13, 2023
6349546
fix indexing in Jacobian of `VecCorrBijector`
harisorgn Apr 13, 2023
e65a78b
add `_logabsdetjac_dist` for `::LKJCholesky`
harisorgn Apr 13, 2023
b6b7fa6
replace function composition for proper barrier
harisorgn Apr 13, 2023
fd24602
add util convert `Transpose -> Matrix` for type stability
harisorgn Apr 13, 2023
1cd62d1
add `LKJCholesky` Jacobian+type tests
harisorgn Apr 13, 2023
f437e68
fix `logabsdetjac` for inverse link
harisorgn Apr 14, 2023
85397e8
use `Cholesky` constructor compatible with `v1.6`
harisorgn Apr 14, 2023
aa5685a
add empty line
harisorgn Apr 17, 2023
df264d6
fix `rrule` for link function
harisorgn Apr 17, 2023
599cb66
add link `rrule` test
harisorgn Apr 17, 2023
9cd42c0
add `rrule` for inverse link
harisorgn Apr 17, 2023
9de4734
remove TODO
harisorgn Apr 17, 2023
befa1cc
add inverse link `rrule` test
harisorgn Apr 17, 2023
6ba1c1f
Update src/bijectors/corr.jl
harisorgn Apr 17, 2023
79ad5f8
add link `rrule` for `LowerTriangular`
harisorgn Apr 18, 2023
19e8843
add `LowerTriangular` chainrule test
harisorgn Apr 18, 2023
4216dbd
Update src/bijectors/corr.jl
harisorgn Apr 18, 2023
e70430f
remove unused util
harisorgn Apr 18, 2023
2caba1c
use `similar` instead of `zeros`
harisorgn Apr 18, 2023
561f6b1
update comments
harisorgn Apr 18, 2023
69f5daa
remove old comment
harisorgn Apr 18, 2023
ca9807e
minimize zero-setting operations in inverse link
harisorgn Apr 18, 2023
1883b36
minimize zero-setting operations in `rrule`
harisorgn Apr 18, 2023
f84b329
add parametric `Val` type to `VecCorrBijector`
harisorgn Apr 18, 2023
2918463
update `VecCorrBijector` tests
harisorgn Apr 18, 2023
2c4920d
use field value instead of `Val`-parametric type
harisorgn Apr 18, 2023
1872bb6
update tests with new `VecCorrBijector`
harisorgn Apr 18, 2023
1250592
`using VecCorrBijector` in test utils
harisorgn Apr 18, 2023
66b4caa
add `VecCorrBijector.mode` check
harisorgn Apr 18, 2023
c5cb535
update `VecCorrBijector` docstring
harisorgn Apr 18, 2023
8a06239
specialise `Zygote@adjoint` for `AbstractMatrix`
harisorgn Apr 18, 2023
44b3b9f
`ReverseDiff` opt-in to `ChainRules`
harisorgn Apr 18, 2023
a5d601d
empty lines format
harisorgn Apr 18, 2023
8783271
add AD test for inverse link
harisorgn Apr 18, 2023
a197076
include `VecCorrBijector` tests
harisorgn Apr 18, 2023
7b9d1b2
remove broken flag for `Tracker`
harisorgn Apr 18, 2023
5d1a7b8
add roundtrip AD tests for `VecCorrBijector`
harisorgn Apr 18, 2023
a0d5e52
remove wrong `ReverseDiff.@grad` for `pd_from_upper`
harisorgn Apr 19, 2023
bd0efff
add corrected `rrule` for `pd_from_upper`
harisorgn Apr 19, 2023
e3314a4
update AD tests
harisorgn Apr 19, 2023
c34ad47
remove `Tracker` from broken
harisorgn Apr 19, 2023
e154061
update zero-filling in `Tracker` pullback
harisorgn Apr 25, 2023
cffb616
fix `Zygote`
harisorgn Apr 25, 2023
c13fce6
merge lines - applying feedback suggestions
harisorgn May 4, 2023
dfeb71e
`unthunk` in `pd_from_upper` rrule
harisorgn May 24, 2023
5210437
split structs into `VecCorrBijector` and `VecCholeskyBijector`
harisorgn May 24, 2023
25a70b4
remove old `Zygote` adjoints
harisorgn May 24, 2023
b056fdd
update tests
harisorgn May 24, 2023
33a8a29
fix `Union` in `@inferred` after splitting structs
harisorgn May 24, 2023
bfa448b
remove `Tracker` tests as support is dropped
harisorgn May 24, 2023
96b90e6
use `permutedims` instead of casting
harisorgn Jun 6, 2023
48edf87
remove `Union` in `@inferred`
harisorgn Jun 6, 2023
a25b36f
Merge branch 'torfjelde/vec-corr' of https://github.com/TuringLang/Bi…
harisorgn Jun 6, 2023
159ddb6
wrap matrix in `Hermitian` before `cholesky`
harisorgn Jun 6, 2023
1bfb2ee
Merge branch 'master' into torfjelde/vec-corr
torfjelde Jun 6, 2023
9c3dec8
add hacky dispatch for `cholesky_factor` and `ReverseDiff`
harisorgn Jun 8, 2023
980660a
Merge branch 'torfjelde/vec-corr' of https://github.com/TuringLang/Bi…
harisorgn Jun 8, 2023
87a6fac
import `cholesky_factor` in ReverseDiff module for hacky dispatch
harisorgn Jun 8, 2023
1d8999f
only use hacky `cholesky_factor` in versions before fix
harisorgn Jun 8, 2023
424607d
change `LKJCholesky` shape to avoid stochastic test failures
harisorgn Jun 8, 2023
be5c1c5
Merge branch 'master' into torfjelde/vec-corr
yebai Jun 10, 2023
6aeebbf
remove old TODOs
harisorgn Jun 12, 2023
62ca234
add explicit zero-filling in link for `CorrBijector`
harisorgn Jun 12, 2023
f439682
Merge branch 'torfjelde/vec-corr' of https://github.com/TuringLang/Bi…
harisorgn Jun 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ using Reexport, Requires
using LinearAlgebra
using MappedArrays
using Base.Iterators: drop
using LinearAlgebra: AbstractTriangular
using LinearAlgebra: AbstractTriangular, Hermitian

using InverseFunctions: InverseFunctions

Expand Down Expand Up @@ -145,6 +145,9 @@ function _logabsdetjac_dist(d::MatrixDistribution, x::AbstractVector{<:AbstractM
return logabsdetjac.((bijector(d),), x)
end

_logabsdetjac_dist(d::LKJCholesky, x::Cholesky) = logabsdetjac(bijector(d), x)
_logabsdetjac_dist(d::LKJCholesky, x::AbstractVector) = logabsdetjac.((bijector(d),), x)

function logpdf_with_trans(d::Distribution, x, transform::Bool)
if ispd(d)
return pd_logpdf_with_trans(d, x, transform)
Expand Down
Loading