-
Notifications
You must be signed in to change notification settings - Fork 33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Define == for Bijector #106
Conversation
Uhhh a heads up: most of these should just work because Julia, e.g. julia> using Bijectors: Exp, Log, Logit
julia> Exp() == Exp()
true
julia> Logit(0., 1.) == Logit(1., 2.)
false
julia> Logit(0., 1.) == Logit(0., 1.)
true
julia> b1 = Exp() ∘ Logit(0., 1.); b2 = Exp() ∘ Logit(0., 1.)
Composed{Tuple{Logit{0,Float64},Exp{0}},0}((Logit{0,Float64}(0.0, 1.0), Exp{0}()))
julia> b1 == b2
true
julia> b1 = Exp() ∘ Logit(0., 1.); b2 = Exp() ∘ Logit(0., 2.)
Composed{Tuple{Logit{0,Float64},Exp{0}},0}((Logit{0,Float64}(0.0, 2.0), Exp{0}()))
julia> b1 == b2
false Are there any particular ones for which it doesn't work? |
Probably only needed for mutable ones, e.g. |
Hmm let me see. |
Related: JuliaLang/julia#4648 Immutable structs are compared by recursively applying |
|
Curious. So from the issue it seems like julia> b1 = Permute(2, 1 => 2, 2 => 1)
Permute{SparseArrays.SparseMatrixCSC{Float64,Int64}}(
[2, 1] = 1.0
[1, 2] = 1.0)
julia> b2 = Permute(2, 1 => 2, 2 => 1)
Permute{SparseArrays.SparseMatrixCSC{Float64,Int64}}(
[2, 1] = 1.0
[1, 2] = 1.0)
julia> b1 == b2
false
julia> b1.A == b2.A
true
julia> isequal(b1.A) == isequal(b2.A)
false Not clear to me why this happens. Might be an issue with |
I think the recursive "rule" doesn't hold for arrays in general. julia> struct T
a
end
julia> T(ones(2)) == T(ones(2))
false |
So we need |
Ah, yeah seems like you're right 👍 |
@torfjelde I think we need julia> using Bijectors, ReverseDiff
julia> Bijectors.Logit(0, 1) == Bijectors.Logit(ReverseDiff.track(0), ReverseDiff.track(1))
false This is not super important as the AD number types showing up in the bijector is already a sign that the support of the distribution is changing so we might as well get |
Co-authored-by: David Widmann <[email protected]>
This PR defines == for bijectors. This is needed to check that the support of a distribution hasn't changed in DynamicPPL.