Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Define == for Bijector #106

Merged
merged 3 commits into from
May 13, 2020
Merged

Define == for Bijector #106

merged 3 commits into from
May 13, 2020

Conversation

mohamed82008
Copy link
Member

This PR defines == for bijectors. This is needed to check that the support of a distribution hasn't changed in DynamicPPL.

@torfjelde
Copy link
Member

torfjelde commented May 13, 2020

Uhhh a heads up: most of these should just work because Julia, e.g.

julia> using Bijectors: Exp, Log, Logit

julia> Exp() == Exp()
true

julia> Logit(0., 1.) == Logit(1., 2.)
false

julia> Logit(0., 1.) == Logit(0., 1.)
true

julia> b1 = Exp()  Logit(0., 1.); b2 = Exp()  Logit(0., 1.)
Composed{Tuple{Logit{0,Float64},Exp{0}},0}((Logit{0,Float64}(0.0, 1.0), Exp{0}()))

julia> b1 == b2
true

julia> b1 = Exp()  Logit(0., 1.); b2 = Exp()  Logit(0., 2.)
Composed{Tuple{Logit{0,Float64},Exp{0}},0}((Logit{0,Float64}(0.0, 2.0), Exp{0}()))

julia> b1 == b2
false

Are there any particular ones for which it doesn't work?

@torfjelde
Copy link
Member

Probably only needed for mutable ones, e.g. PlanarLayer, RadialLayer and Normalize (I'm not entirely sure why we need these to be mutable though; could just make the mutations be performed elementwise using .= since the fields are all arrays anyways)

src/interface.jl Outdated Show resolved Hide resolved
@mohamed82008
Copy link
Member Author

Uhhh a heads up: most of these should just work because Julia, e.g.
Are there any particular ones for which it doesn't work?

Hmm let me see.

@torfjelde
Copy link
Member

Related: JuliaLang/julia#4648

Immutable structs are compared by recursively applying ==, but this is not the case for mutable (because issues: JuliaLang/julia#4648 (comment))

@mohamed82008
Copy link
Member Author

Permute seems to need an == method too.

@torfjelde
Copy link
Member

Curious. So from the issue it seems like == will recursively call isequal and, lo and behold,

julia> b1 = Permute(2, 1 => 2, 2 => 1)
Permute{SparseArrays.SparseMatrixCSC{Float64,Int64}}(
  [2, 1]  =  1.0
  [1, 2]  =  1.0)

julia> b2 = Permute(2, 1 => 2, 2 => 1)
Permute{SparseArrays.SparseMatrixCSC{Float64,Int64}}(
  [2, 1]  =  1.0
  [1, 2]  =  1.0)

julia> b1 == b2
false

julia> b1.A == b2.A
true

julia> isequal(b1.A) == isequal(b2.A)
false

Not clear to me why this happens. Might be an issue with SparseArrays.jl, because there's nothing in the docstring of isequal that tells me they shouldn't be equal.

@mohamed82008
Copy link
Member Author

I think the recursive "rule" doesn't hold for arrays in general.

julia> struct T
           a
       end

julia> T(ones(2)) == T(ones(2))
false

@mohamed82008
Copy link
Member Author

So we need == methods for any bijector that can have an array field.

@torfjelde
Copy link
Member

Ah, yeah seems like you're right 👍

@mohamed82008
Copy link
Member Author

@torfjelde I think we need == for all bijectors with numeric fields for this to return true:

julia> using Bijectors, ReverseDiff

julia> Bijectors.Logit(0, 1) == Bijectors.Logit(ReverseDiff.track(0), ReverseDiff.track(1))
false

This is not super important as the AD number types showing up in the bijector is already a sign that the support of the distribution is changing so we might as well get false here. But I would still like to see true in the above case.

src/bijectors/logit.jl Outdated Show resolved Hide resolved
Co-authored-by: David Widmann <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants