Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Define == for Bijector #106

Merged
merged 3 commits into from
May 13, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/bijectors/composed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@ end

isclosedform(b::Composed) = all(isclosedform, b.ts)
up1(b::Composed) = Composed(up1.(b.ts))
function Base.:(==)(b1::Composed{<:Any, N}, b2::Composed{<:Any, N}) where {N}
ts1, ts2 = b1.ts, b2.ts
if !(ts1 isa Tuple && ts2 isa Tuple || ts1 isa Vector && ts2 isa Vector)
return false
end
return all(ts1 .== ts2)
end

"""
composel(ts::Bijector...)::Composed{<:Tuple}
Expand Down
2 changes: 2 additions & 0 deletions src/bijectors/logit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ function Logit(a, b)
Logit{0, T}(a, b)
end
up1(b::Logit{N, T}) where {N, T} = Logit{N + 1, T}(b.a, b.b)
# For equality of Logit with Float64 fields to one with Duals
Base.(==)(b1::Logit, b2::Logit) = b1.a == b2.a && b1.b == b2.b
mohamed82008 marked this conversation as resolved.
Show resolved Hide resolved

(b::Logit{0})(x::Real) = _logit(x, b.a, b.b)
(b::Logit{0})(x) = _logit.(x, b.a, b.b)
Expand Down
8 changes: 8 additions & 0 deletions src/bijectors/normalise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ mutable struct InvertibleBatchNorm{T1,T2,T3} <: Bijector{1}
eps :: T3
mtm :: T3 # momentum
end
function Base.:(==)(b1::InvertibleBatchNorm, b2::InvertibleBatchNorm)
return b1.b == b2.b &&
b1.logs == b2.logs &&
b1.m == b2.m &&
b1.v == b2.v &&
b1.eps == b2.eps &&
b1.mtm == b2.mtm
end

function InvertibleBatchNorm(
chs::Int;
Expand Down
2 changes: 2 additions & 0 deletions src/bijectors/permute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ struct Permute{A} <: Bijector{1}
A::A
end

Base.:(==)(b1::Permute, b2::Permute) = b1.A == b2.A

function Permute(indices::AbstractVector{Int})
# construct a sparse-matrix for use in the multiplication
n = length(indices)
Expand Down
3 changes: 3 additions & 0 deletions src/bijectors/planar_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ mutable struct PlanarLayer{T1<:AbstractVector{<:Real}, T2<:Real} <: Bijector{1}
u::T1
b::T2
end
function Base.:(==)(b1::PlanarLayer, b2::PlanarLayer)
return b1.w == b2.w && b1.u == b2.u && b1.b == b2.b
end

function get_u_hat(u, w)
# To preserve invertibility
Expand Down
3 changes: 3 additions & 0 deletions src/bijectors/radial_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ mutable struct RadialLayer{T1 <: Real, T2 <: AbstractVector{<:Real}} <: Bijector
β::T1
z_0::T2
end
function Base.:(==)(b1::RadialLayer, b2::RadialLayer)
return b1.α_ == b2.α_ && b1.β == b2.β && b1.z_0 == b2.z_0
end

function RadialLayer(dims::Int, container=Array)
α_ = randn()
Expand Down
2 changes: 2 additions & 0 deletions src/bijectors/scale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ struct Scale{T, N} <: Bijector{N}
a::T
end

Base.:(==)(b1::Scale{<:Any, N}, b2::Scale{<:Any, N}) where {N} = b1.a == b2.a

function Scale(a::Union{Real,AbstractArray}; dim::Val{D} = Val(ndims(a))) where D
return Scale{typeof(a), D}(a)
end
Expand Down
2 changes: 2 additions & 0 deletions src/bijectors/shift.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ struct Shift{T, N} <: Bijector{N}
a::T
end

Base.:(==)(b1::Shift{<:Any, N}, b2::Shift{<:Any, N}) where {N} = b1.a == b2.a

function Shift(a::Union{Real,AbstractArray}; dim::Val{D} = Val(ndims(a))) where D
return Shift{typeof(a), D}(a)
end
Expand Down
8 changes: 8 additions & 0 deletions src/bijectors/stacked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ end
Stacked(bs, ranges::AbstractArray) = Stacked(bs, tuple(ranges...))
Stacked(bs) = Stacked(bs, tuple([i:i for i = 1:length(bs)]...))

function Base.:(==)(b1::Stacked, b2::Stacked)
bs1, bs2 = b1.bs, b2.bs
if !(bs1 isa Tuple && bs2 isa Tuple || bs1 isa Vector && bs2 isa Vector)
return false
end
return all(bs1 .== bs2) && all(b1.ranges .== b2.ranges)
end

isclosedform(b::Stacked) = all(isclosedform, b.bs)

stack(bs::Bijector{0}...) = Stacked(bs)
Expand Down
4 changes: 4 additions & 0 deletions src/bijectors/truncated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ function TruncatedBijector{N}(lb::T1, ub::T2) where {N, T1, T2}
end
up1(b::TruncatedBijector{N}) where {N} = TruncatedBijector{N + 1}(b.lb, b.ub)

function Base.:(==)(b1::TruncatedBijector, b2::TruncatedBijector)
return b1.lb == b2.lb && b1.ub == b2.ub
end

function (b::TruncatedBijector{0})(x::Real)
a, b = b.lb, b.ub
truncated_link(_clamp(x, a, b), a, b)
Expand Down
1 change: 1 addition & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ up1(b::Inverse) = Inverse(up1(b.orig))

inv(b::Bijector) = Inverse(b)
inv(ib::Inverse{<:Bijector}) = ib.orig
Base.:(==)(b1::Inverse{<:Bijector}, b2::Inverse{<:Bijector}) = b1.orig == b2.orig

"""
logabsdetjac(b::Bijector, x)
Expand Down
59 changes: 57 additions & 2 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using ForwardDiff
using Tracker

using Bijectors
using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, ADBijector
using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, ADBijector

Random.seed!(123)

Expand Down Expand Up @@ -599,7 +599,7 @@ end
@test res.rv == [exp(x[1]), log(x[2]), x[3] + 5.0]
@test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:3])
@test res.logabsdetjac == logabsdetjac(sb, x)


# TODO: change when we have dimensionality in the type
sb = @inferred Stacked((Bijectors.Exp(), Bijectors.SimplexBijector()), [1:1, 2:3])
Expand Down Expand Up @@ -735,4 +735,59 @@ end
@test Δ_forwarddiff ≈ Δ_tracker
end

@testset "Equality" begin
bs = [
Identity{0}(),
Identity{1}(),
Identity{2}(),
Exp{0}(),
Exp{1}(),
Exp{2}(),
Log{0}(),
Log{1}(),
Log{2}(),
Scale(2.0),
Scale(3.0),
Scale(rand(2,2)),
Scale(rand(2,2)),
Shift(2.0),
Shift(3.0),
Shift(rand(2)),
Shift(rand(2)),
Logit(1.0, 2.0),
Logit(1.0, 3.0),
Logit(2.0, 3.0),
Logit(0.0, 2.0),
InvertibleBatchNorm(2),
InvertibleBatchNorm(3),
PDBijector(),
Permute([1.0, 2.0, 3.0]),
Permute([2.0, 3.0, 4.0]),
PlanarLayer(2),
PlanarLayer(3),
RadialLayer(2),
RadialLayer(3),
SimplexBijector(),
Stacked((Exp{0}(), Log{0}())),
Stacked((Log{0}(), Exp{0}())),
Stacked([Exp{0}(), Log{0}()]),
Stacked([Log{0}(), Exp{0}()]),
Composed((Exp{0}(), Log{0}())),
Composed((Log{0}(), Exp{0}())),
Composed([Exp{0}(), Log{0}()]),
Composed([Log{0}(), Exp{0}()]),
TruncatedBijector(1.0, 2.0),
TruncatedBijector(1.0, 3.0),
TruncatedBijector(0.0, 2.0),
]
for i in 1:length(bs), j in 1:length(bs)
if i == j
@test bs[i] == deepcopy(bs[j])
@test inv(bs[i]) == inv(deepcopy(bs[j]))
else
@test bs[i] != bs[j]
end
end
end

include("norm_flows.jl")