Skip to content

Commit

Permalink
define == for Bijector
Browse files Browse the repository at this point in the history
  • Loading branch information
mohamed82008 committed May 13, 2020
1 parent 6760cfb commit 8517009
Show file tree
Hide file tree
Showing 15 changed files with 90 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/bijectors/composed.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ end

isclosedform(b::Composed) = all(isclosedform, b.ts)
up1(b::Composed) = Composed(up1.(b.ts))
==(b1::Composed{A, N}, b2::Composed{A, N}) where {A, N} = all(b1.ts .== b2.ts)

"""
composel(ts::Bijector...)::Composed{<:Tuple}
Expand Down
2 changes: 2 additions & 0 deletions src/bijectors/exp_log.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ struct Exp{N} <: Bijector{N} end
struct Log{N} <: Bijector{N} end
up1(::Exp{N}) where {N} = Exp{N + 1}()
up1(::Log{N}) where {N} = Log{N + 1}()
==(::Exp{N}, ::Exp{N}) where {N} = true
==(::Log{N}, ::Log{N}) where {N} = true

Exp() = Exp{0}()
Log() = Log{0}()
Expand Down
1 change: 1 addition & 0 deletions src/bijectors/logit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ function Logit(a, b)
Logit{0, T}(a, b)
end
up1(b::Logit{N, T}) where {N, T} = Logit{N + 1, T}(b.a, b.b)
==(b1::Logit{N}, b2::Logit{N}) where {N} = b1.a == b2.a && b1.b == b2.b

(b::Logit{0})(x::Real) = _logit(x, b.a, b.b)
(b::Logit{0})(x) = _logit.(x, b.a, b.b)
Expand Down
8 changes: 8 additions & 0 deletions src/bijectors/normalise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@ mutable struct InvertibleBatchNorm{T1,T2,T3} <: Bijector{1}
eps :: T3
mtm :: T3 # momentum
end
function ==(b1::InvertibleBatchNorm, b2::InvertibleBatchNorm)
return b1.b == b2.b &&
b1.logs == b2.logs &&
b1.m == b2.m &&
b1.v == b2.v &&
b1.eps == b2.eps &&
b1.mtm == b2.mtm
end

function InvertibleBatchNorm(
chs::Int;
Expand Down
2 changes: 2 additions & 0 deletions src/bijectors/pd.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
struct PDBijector <: Bijector{2} end

==(::PDBijector, ::PDBijector) = true

# This function has custom adjoints defined for Tracker, Zygote and ReverseDiff.
# I couldn't find a mutation-free implementation that maintains TrackedArrays in Tracker
# and ReverseDiff, hence the need for custom adjoints.
Expand Down
2 changes: 2 additions & 0 deletions src/bijectors/permute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ struct Permute{A} <: Bijector{1}
A::A
end

==(b1::Permute, b2::Permute) = b1.A == b2.A

function Permute(indices::AbstractVector{Int})
# construct a sparse-matrix for use in the multiplication
n = length(indices)
Expand Down
4 changes: 4 additions & 0 deletions src/bijectors/planar_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ mutable struct PlanarLayer{T1<:AbstractVector{<:Real}, T2<:Real} <: Bijector{1}
b::T2
end

function ==(b1::PlanarLayer, b2::PlanarLayer)
return b1.w == b2.w && b1.u == b2.u && b1.b == b2.b
end

function get_u_hat(u, w)
# To preserve invertibility
x = w' * u
Expand Down
4 changes: 4 additions & 0 deletions src/bijectors/radial_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ mutable struct RadialLayer{T1 <: Real, T2 <: AbstractVector{<:Real}} <: Bijector
z_0::T2
end

function ==(b1::RadialLayer, b2::RadialLayer)
return b1.α_ == b2.α_ && b1.β == b2.β && b1.z_0 == b2.z_0
end

function RadialLayer(dims::Int, container=Array)
α_ = randn()
β = randn()
Expand Down
2 changes: 2 additions & 0 deletions src/bijectors/scale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ struct Scale{T, N} <: Bijector{N}
a::T
end

==(b1::Scale{<:Any, N}, b2::Scale{<:Any, N}) where {N} = b1.a == b2.a

function Scale(a::Union{Real,AbstractArray}; dim::Val{D} = Val(ndims(a))) where D
return Scale{typeof(a), D}(a)
end
Expand Down
2 changes: 2 additions & 0 deletions src/bijectors/shift.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ struct Shift{T, N} <: Bijector{N}
a::T
end

==(b1::Shift{<:Any, N}, b2::Shift{<:Any, N}) where {N} = b1.a == b2.a

function Shift(a::Union{Real,AbstractArray}; dim::Val{D} = Val(ndims(a))) where D
return Shift{typeof(a), D}(a)
end
Expand Down
2 changes: 2 additions & 0 deletions src/bijectors/simplex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ function SimplexBijector{N}() where {N}
end
end

==(::SimplexBijector{N, T}, ::SimplexBijector{N, T}) where {N, T} = true

(b::SimplexBijector{1})(x::AbstractVector) = _simplex_bijector(x, b)
(b::SimplexBijector{1})(y::AbstractVector, x::AbstractVector) = _simplex_bijector!(y, x, b)
function _simplex_bijector(x::AbstractVector, b::SimplexBijector{1})
Expand Down
4 changes: 4 additions & 0 deletions src/bijectors/stacked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ end
Stacked(bs, ranges::AbstractArray) = Stacked(bs, tuple(ranges...))
Stacked(bs) = Stacked(bs, tuple([i:i for i = 1:length(bs)]...))

function ==(b1::Stacked, b2::Stacked)
return all(b1.bs .== b2.bs) && all(b1.ranges .== b2.ranges)
end

isclosedform(b::Stacked) = all(isclosedform, b.bs)

stack(bs::Bijector{0}...) = Stacked(bs)
Expand Down
4 changes: 4 additions & 0 deletions src/bijectors/truncated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ function TruncatedBijector{N}(lb::T1, ub::T2) where {N, T1, T2}
end
up1(b::TruncatedBijector{N}) where {N} = TruncatedBijector{N + 1}(b.lb, b.ub)

function ==(b1::TruncatedBijector, b2::TruncatedBijector)
return b1.lb == b2.lb && b1.ub == b2.ub
end

function (b::TruncatedBijector{0})(x::Real)
a, b = b.lb, b.ub
truncated_link(_clamp(x, a, b), a, b)
Expand Down
7 changes: 6 additions & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import Base: inv,
import Base: inv, , ==

import Random: AbstractRNG
import Distributions: logpdf, rand, rand!, _rand!, _logpdf
Expand Down Expand Up @@ -41,6 +41,9 @@ dimension(b::Bijector{N}) where {N} = N
dimension(b::Type{<:Bijector{N}}) where {N} = N

Broadcast.broadcastable(b::Bijector) = Ref(b)
function ==(b1::Bijector, b2::Bijector)
return false
end

"""
isclosedform(b::Bijector)::bool
Expand Down Expand Up @@ -70,6 +73,7 @@ up1(b::Inverse) = Inverse(up1(b.orig))

inv(b::Bijector) = Inverse(b)
inv(ib::Inverse{<:Bijector}) = ib.orig
==(b1::Inverse{<:Bijector}, b2::Inverse{<:Bijector}) = b1.orig == b2.orig

"""
logabsdetjac(b::Bijector, x)
Expand Down Expand Up @@ -111,6 +115,7 @@ struct Identity{N} <: Bijector{N} end
(::Identity)(x) = copy(x)
inv(b::Identity) = b
up1(::Identity{N}) where {N} = Identity{N + 1}()
==(b1::Identity{N}, b2::Identity{N}) where {N} = true

logabsdetjac(::Identity{0}, x::Real) = zero(eltype(x))
@generated function logabsdetjac(
Expand Down
48 changes: 46 additions & 2 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using ForwardDiff
using Tracker

using Bijectors
using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, ADBijector
using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, ADBijector

Random.seed!(123)

Expand Down Expand Up @@ -599,7 +599,7 @@ end
@test res.rv == [exp(x[1]), log(x[2]), x[3] + 5.0]
@test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:3])
@test res.logabsdetjac == logabsdetjac(sb, x)


# TODO: change when we have dimensionality in the type
sb = @inferred Stacked((Bijectors.Exp(), Bijectors.SimplexBijector()), [1:1, 2:3])
Expand Down Expand Up @@ -735,4 +735,48 @@ end
@test Δ_forwarddiff Δ_tracker
end

@testset "Equality" begin
bs = [
Identity{0}(),
Identity{1}(),
Identity{2}(),
Exp{0}(),
Exp{1}(),
Exp{2}(),
Log{0}(),
Log{1}(),
Log{2}(),
Scale(2.0),
Scale(3.0),
Shift(2.0),
Shift(3.0),
Logit(1.0, 2.0),
Logit(1.0, 3.0),
Logit(2.0, 3.0),
Logit(0.0, 2.0),
InvertibleBatchNorm(2),
InvertibleBatchNorm(3),
PDBijector(),
Permute([1.0, 2.0, 3.0]),
Permute([2.0, 3.0, 4.0]),
PlanarLayer(2),
PlanarLayer(3),
RadialLayer(2),
RadialLayer(3),
SimplexBijector(),
Stacked((Exp{0}(), Log{0}())),
Stacked((Log{0}(), Exp{0}())),
TruncatedBijector(1.0, 2.0),
TruncatedBijector(1.0, 3.0),
TruncatedBijector(0.0, 2.0),
]
for i in 1:length(bs), j in 1:length(bs)
if i == j
@test bs[i] == bs[j]
else
@test bs[i] != bs[j]
end
end
end

include("norm_flows.jl")

0 comments on commit 8517009

Please sign in to comment.