Skip to content

Commit

Permalink
Define == for Bijector (#106)
Browse files Browse the repository at this point in the history
* define == for bijectors and test

* add == method for Logit

* Update src/bijectors/logit.jl

Co-authored-by: David Widmann <[email protected]>

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
mohamed82008 and devmotion authored May 13, 2020
1 parent 6760cfb commit a763e3f
Show file tree
Hide file tree
Showing 12 changed files with 100 additions and 3 deletions.
7 changes: 7 additions & 0 deletions src/bijectors/composed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@ end

isclosedform(b::Composed) = all(isclosedform, b.ts)
up1(b::Composed) = Composed(up1.(b.ts))
function Base.:(==)(b1::Composed{<:Any, N}, b2::Composed{<:Any, N}) where {N}
ts1, ts2 = b1.ts, b2.ts
if !(ts1 isa Tuple && ts2 isa Tuple || ts1 isa Vector && ts2 isa Vector)
return false
end
return all(ts1 .== ts2)
end

"""
composel(ts::Bijector...)::Composed{<:Tuple}
Expand Down
4 changes: 3 additions & 1 deletion src/bijectors/logit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ function Logit(a, b)
Logit{0, T}(a, b)
end
up1(b::Logit{N, T}) where {N, T} = Logit{N + 1, T}(b.a, b.b)
# For equality of Logit with Float64 fields to one with Duals
Base.:(==)(b1::Logit, b2::Logit) = b1.a == b2.a && b1.b == b2.b

(b::Logit{0})(x::Real) = _logit(x, b.a, b.b)
(b::Logit{0})(x) = _logit.(x, b.a, b.b)
Expand All @@ -36,4 +38,4 @@ logabsdetjac(b::Logit{2}, x::AbstractMatrix) = sum(logit_logabsdetjac.(x, b.a, b
logabsdetjac(b::Logit{2}, x::AbstractArray{<:AbstractMatrix}) = map(x) do x
logabsdetjac(b, x)
end
logit_logabsdetjac(x, a, b) = -log((x - a) * (b - x) / (b - a))
logit_logabsdetjac(x, a, b) = -log((x - a) * (b - x) / (b - a))
8 changes: 8 additions & 0 deletions src/bijectors/normalise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ mutable struct InvertibleBatchNorm{T1,T2,T3} <: Bijector{1}
eps :: T3
mtm :: T3 # momentum
end
function Base.:(==)(b1::InvertibleBatchNorm, b2::InvertibleBatchNorm)
return b1.b == b2.b &&
b1.logs == b2.logs &&
b1.m == b2.m &&
b1.v == b2.v &&
b1.eps == b2.eps &&
b1.mtm == b2.mtm
end

function InvertibleBatchNorm(
chs::Int;
Expand Down
2 changes: 2 additions & 0 deletions src/bijectors/permute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ struct Permute{A} <: Bijector{1}
A::A
end

Base.:(==)(b1::Permute, b2::Permute) = b1.A == b2.A

function Permute(indices::AbstractVector{Int})
# construct a sparse-matrix for use in the multiplication
n = length(indices)
Expand Down
3 changes: 3 additions & 0 deletions src/bijectors/planar_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ mutable struct PlanarLayer{T1<:AbstractVector{<:Real}, T2<:Real} <: Bijector{1}
u::T1
b::T2
end
function Base.:(==)(b1::PlanarLayer, b2::PlanarLayer)
return b1.w == b2.w && b1.u == b2.u && b1.b == b2.b
end

function get_u_hat(u, w)
# To preserve invertibility
Expand Down
3 changes: 3 additions & 0 deletions src/bijectors/radial_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ mutable struct RadialLayer{T1 <: Real, T2 <: AbstractVector{<:Real}} <: Bijector
β::T1
z_0::T2
end
function Base.:(==)(b1::RadialLayer, b2::RadialLayer)
return b1.α_ == b2.α_ && b1.β == b2.β && b1.z_0 == b2.z_0
end

function RadialLayer(dims::Int, container=Array)
α_ = randn()
Expand Down
2 changes: 2 additions & 0 deletions src/bijectors/scale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ struct Scale{T, N} <: Bijector{N}
a::T
end

Base.:(==)(b1::Scale{<:Any, N}, b2::Scale{<:Any, N}) where {N} = b1.a == b2.a

function Scale(a::Union{Real,AbstractArray}; dim::Val{D} = Val(ndims(a))) where D
return Scale{typeof(a), D}(a)
end
Expand Down
2 changes: 2 additions & 0 deletions src/bijectors/shift.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ struct Shift{T, N} <: Bijector{N}
a::T
end

Base.:(==)(b1::Shift{<:Any, N}, b2::Shift{<:Any, N}) where {N} = b1.a == b2.a

function Shift(a::Union{Real,AbstractArray}; dim::Val{D} = Val(ndims(a))) where D
return Shift{typeof(a), D}(a)
end
Expand Down
8 changes: 8 additions & 0 deletions src/bijectors/stacked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,14 @@ end
Stacked(bs, ranges::AbstractArray) = Stacked(bs, tuple(ranges...))
Stacked(bs) = Stacked(bs, tuple([i:i for i = 1:length(bs)]...))

function Base.:(==)(b1::Stacked, b2::Stacked)
bs1, bs2 = b1.bs, b2.bs
if !(bs1 isa Tuple && bs2 isa Tuple || bs1 isa Vector && bs2 isa Vector)
return false
end
return all(bs1 .== bs2) && all(b1.ranges .== b2.ranges)
end

isclosedform(b::Stacked) = all(isclosedform, b.bs)

stack(bs::Bijector{0}...) = Stacked(bs)
Expand Down
4 changes: 4 additions & 0 deletions src/bijectors/truncated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ function TruncatedBijector{N}(lb::T1, ub::T2) where {N, T1, T2}
end
up1(b::TruncatedBijector{N}) where {N} = TruncatedBijector{N + 1}(b.lb, b.ub)

function Base.:(==)(b1::TruncatedBijector, b2::TruncatedBijector)
return b1.lb == b2.lb && b1.ub == b2.ub
end

function (b::TruncatedBijector{0})(x::Real)
a, b = b.lb, b.ub
truncated_link(_clamp(x, a, b), a, b)
Expand Down
1 change: 1 addition & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ up1(b::Inverse) = Inverse(up1(b.orig))

inv(b::Bijector) = Inverse(b)
inv(ib::Inverse{<:Bijector}) = ib.orig
Base.:(==)(b1::Inverse{<:Bijector}, b2::Inverse{<:Bijector}) = b1.orig == b2.orig

"""
logabsdetjac(b::Bijector, x)
Expand Down
59 changes: 57 additions & 2 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using ForwardDiff
using Tracker

using Bijectors
using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, ADBijector
using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, ADBijector

Random.seed!(123)

Expand Down Expand Up @@ -599,7 +599,7 @@ end
@test res.rv == [exp(x[1]), log(x[2]), x[3] + 5.0]
@test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:3])
@test res.logabsdetjac == logabsdetjac(sb, x)


# TODO: change when we have dimensionality in the type
sb = @inferred Stacked((Bijectors.Exp(), Bijectors.SimplexBijector()), [1:1, 2:3])
Expand Down Expand Up @@ -735,4 +735,59 @@ end
@test Δ_forwarddiff Δ_tracker
end

@testset "Equality" begin
bs = [
Identity{0}(),
Identity{1}(),
Identity{2}(),
Exp{0}(),
Exp{1}(),
Exp{2}(),
Log{0}(),
Log{1}(),
Log{2}(),
Scale(2.0),
Scale(3.0),
Scale(rand(2,2)),
Scale(rand(2,2)),
Shift(2.0),
Shift(3.0),
Shift(rand(2)),
Shift(rand(2)),
Logit(1.0, 2.0),
Logit(1.0, 3.0),
Logit(2.0, 3.0),
Logit(0.0, 2.0),
InvertibleBatchNorm(2),
InvertibleBatchNorm(3),
PDBijector(),
Permute([1.0, 2.0, 3.0]),
Permute([2.0, 3.0, 4.0]),
PlanarLayer(2),
PlanarLayer(3),
RadialLayer(2),
RadialLayer(3),
SimplexBijector(),
Stacked((Exp{0}(), Log{0}())),
Stacked((Log{0}(), Exp{0}())),
Stacked([Exp{0}(), Log{0}()]),
Stacked([Log{0}(), Exp{0}()]),
Composed((Exp{0}(), Log{0}())),
Composed((Log{0}(), Exp{0}())),
Composed([Exp{0}(), Log{0}()]),
Composed([Log{0}(), Exp{0}()]),
TruncatedBijector(1.0, 2.0),
TruncatedBijector(1.0, 3.0),
TruncatedBijector(0.0, 2.0),
]
for i in 1:length(bs), j in 1:length(bs)
if i == j
@test bs[i] == deepcopy(bs[j])
@test inv(bs[i]) == inv(deepcopy(bs[j]))
else
@test bs[i] != bs[j]
end
end
end

include("norm_flows.jl")

0 comments on commit a763e3f

Please sign in to comment.