diff --git a/src/bijectors/composed.jl b/src/bijectors/composed.jl index 43b319dd..0d1c85d7 100644 --- a/src/bijectors/composed.jl +++ b/src/bijectors/composed.jl @@ -94,6 +94,13 @@ end isclosedform(b::Composed) = all(isclosedform, b.ts) up1(b::Composed) = Composed(up1.(b.ts)) +function Base.:(==)(b1::Composed{<:Any, N}, b2::Composed{<:Any, N}) where {N} + ts1, ts2 = b1.ts, b2.ts + if !(ts1 isa Tuple && ts2 isa Tuple || ts1 isa Vector && ts2 isa Vector) + return false + end + return all(ts1 .== ts2) +end """ composel(ts::Bijector...)::Composed{<:Tuple} diff --git a/src/bijectors/logit.jl b/src/bijectors/logit.jl index f0f6386d..f342b837 100644 --- a/src/bijectors/logit.jl +++ b/src/bijectors/logit.jl @@ -12,6 +12,8 @@ function Logit(a, b) Logit{0, T}(a, b) end up1(b::Logit{N, T}) where {N, T} = Logit{N + 1, T}(b.a, b.b) +# For equality of Logit with Float64 fields to one with Duals +Base.:(==)(b1::Logit, b2::Logit) = b1.a == b2.a && b1.b == b2.b (b::Logit{0})(x::Real) = _logit(x, b.a, b.b) (b::Logit{0})(x) = _logit.(x, b.a, b.b) @@ -36,4 +38,4 @@ logabsdetjac(b::Logit{2}, x::AbstractMatrix) = sum(logit_logabsdetjac.(x, b.a, b logabsdetjac(b::Logit{2}, x::AbstractArray{<:AbstractMatrix}) = map(x) do x logabsdetjac(b, x) end -logit_logabsdetjac(x, a, b) = -log((x - a) * (b - x) / (b - a)) \ No newline at end of file +logit_logabsdetjac(x, a, b) = -log((x - a) * (b - x) / (b - a)) diff --git a/src/bijectors/normalise.jl b/src/bijectors/normalise.jl index 4bb8ded9..300c46e2 100644 --- a/src/bijectors/normalise.jl +++ b/src/bijectors/normalise.jl @@ -14,6 +14,14 @@ mutable struct InvertibleBatchNorm{T1,T2,T3} <: Bijector{1} eps :: T3 mtm :: T3 # momentum end +function Base.:(==)(b1::InvertibleBatchNorm, b2::InvertibleBatchNorm) + return b1.b == b2.b && + b1.logs == b2.logs && + b1.m == b2.m && + b1.v == b2.v && + b1.eps == b2.eps && + b1.mtm == b2.mtm +end function InvertibleBatchNorm( chs::Int; diff --git a/src/bijectors/permute.jl b/src/bijectors/permute.jl index 9fc86bfc..d4fdef7b 100644 --- a/src/bijectors/permute.jl +++ b/src/bijectors/permute.jl @@ -85,6 +85,8 @@ struct Permute{A} <: Bijector{1} A::A end +Base.:(==)(b1::Permute, b2::Permute) = b1.A == b2.A + function Permute(indices::AbstractVector{Int}) # construct a sparse-matrix for use in the multiplication n = length(indices) diff --git a/src/bijectors/planar_layer.jl b/src/bijectors/planar_layer.jl index 03063fbd..a38a6be0 100644 --- a/src/bijectors/planar_layer.jl +++ b/src/bijectors/planar_layer.jl @@ -20,6 +20,9 @@ mutable struct PlanarLayer{T1<:AbstractVector{<:Real}, T2<:Real} <: Bijector{1} u::T1 b::T2 end +function Base.:(==)(b1::PlanarLayer, b2::PlanarLayer) + return b1.w == b2.w && b1.u == b2.u && b1.b == b2.b +end function get_u_hat(u, w) # To preserve invertibility diff --git a/src/bijectors/radial_layer.jl b/src/bijectors/radial_layer.jl index 5bade504..2f4367b1 100644 --- a/src/bijectors/radial_layer.jl +++ b/src/bijectors/radial_layer.jl @@ -18,6 +18,9 @@ mutable struct RadialLayer{T1 <: Real, T2 <: AbstractVector{<:Real}} <: Bijector β::T1 z_0::T2 end +function Base.:(==)(b1::RadialLayer, b2::RadialLayer) + return b1.α_ == b2.α_ && b1.β == b2.β && b1.z_0 == b2.z_0 +end function RadialLayer(dims::Int, container=Array) α_ = randn() diff --git a/src/bijectors/scale.jl b/src/bijectors/scale.jl index 55305796..5d728a05 100644 --- a/src/bijectors/scale.jl +++ b/src/bijectors/scale.jl @@ -2,6 +2,8 @@ struct Scale{T, N} <: Bijector{N} a::T end +Base.:(==)(b1::Scale{<:Any, N}, b2::Scale{<:Any, N}) where {N} = b1.a == b2.a + function Scale(a::Union{Real,AbstractArray}; dim::Val{D} = Val(ndims(a))) where D return Scale{typeof(a), D}(a) end diff --git a/src/bijectors/shift.jl b/src/bijectors/shift.jl index 3017ff99..ab975bb5 100644 --- a/src/bijectors/shift.jl +++ b/src/bijectors/shift.jl @@ -5,6 +5,8 @@ struct Shift{T, N} <: Bijector{N} a::T end +Base.:(==)(b1::Shift{<:Any, N}, b2::Shift{<:Any, N}) where {N} = b1.a == b2.a + function Shift(a::Union{Real,AbstractArray}; dim::Val{D} = Val(ndims(a))) where D return Shift{typeof(a), D}(a) end diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 37397d2d..d5b3faf6 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -44,6 +44,14 @@ end Stacked(bs, ranges::AbstractArray) = Stacked(bs, tuple(ranges...)) Stacked(bs) = Stacked(bs, tuple([i:i for i = 1:length(bs)]...)) +function Base.:(==)(b1::Stacked, b2::Stacked) + bs1, bs2 = b1.bs, b2.bs + if !(bs1 isa Tuple && bs2 isa Tuple || bs1 isa Vector && bs2 isa Vector) + return false + end + return all(bs1 .== bs2) && all(b1.ranges .== b2.ranges) +end + isclosedform(b::Stacked) = all(isclosedform, b.bs) stack(bs::Bijector{0}...) = Stacked(bs) diff --git a/src/bijectors/truncated.jl b/src/bijectors/truncated.jl index 2cb95e6b..e403f988 100644 --- a/src/bijectors/truncated.jl +++ b/src/bijectors/truncated.jl @@ -11,6 +11,10 @@ function TruncatedBijector{N}(lb::T1, ub::T2) where {N, T1, T2} end up1(b::TruncatedBijector{N}) where {N} = TruncatedBijector{N + 1}(b.lb, b.ub) +function Base.:(==)(b1::TruncatedBijector, b2::TruncatedBijector) + return b1.lb == b2.lb && b1.ub == b2.ub +end + function (b::TruncatedBijector{0})(x::Real) a, b = b.lb, b.ub truncated_link(_clamp(x, a, b), a, b) diff --git a/src/interface.jl b/src/interface.jl index a5ec2754..656b3912 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -70,6 +70,7 @@ up1(b::Inverse) = Inverse(up1(b.orig)) inv(b::Bijector) = Inverse(b) inv(ib::Inverse{<:Bijector}) = ib.orig +Base.:(==)(b1::Inverse{<:Bijector}, b2::Inverse{<:Bijector}) = b1.orig == b2.orig """ logabsdetjac(b::Bijector, x) diff --git a/test/interface.jl b/test/interface.jl index 7000f1da..31b0eadc 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -5,7 +5,7 @@ using ForwardDiff using Tracker using Bijectors -using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, ADBijector +using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, ADBijector Random.seed!(123) @@ -599,7 +599,7 @@ end @test res.rv == [exp(x[1]), log(x[2]), x[3] + 5.0] @test logabsdetjac(sb, x) == sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i = 1:3]) @test res.logabsdetjac == logabsdetjac(sb, x) - + # TODO: change when we have dimensionality in the type sb = @inferred Stacked((Bijectors.Exp(), Bijectors.SimplexBijector()), [1:1, 2:3]) @@ -735,4 +735,59 @@ end @test Δ_forwarddiff ≈ Δ_tracker end +@testset "Equality" begin + bs = [ + Identity{0}(), + Identity{1}(), + Identity{2}(), + Exp{0}(), + Exp{1}(), + Exp{2}(), + Log{0}(), + Log{1}(), + Log{2}(), + Scale(2.0), + Scale(3.0), + Scale(rand(2,2)), + Scale(rand(2,2)), + Shift(2.0), + Shift(3.0), + Shift(rand(2)), + Shift(rand(2)), + Logit(1.0, 2.0), + Logit(1.0, 3.0), + Logit(2.0, 3.0), + Logit(0.0, 2.0), + InvertibleBatchNorm(2), + InvertibleBatchNorm(3), + PDBijector(), + Permute([1.0, 2.0, 3.0]), + Permute([2.0, 3.0, 4.0]), + PlanarLayer(2), + PlanarLayer(3), + RadialLayer(2), + RadialLayer(3), + SimplexBijector(), + Stacked((Exp{0}(), Log{0}())), + Stacked((Log{0}(), Exp{0}())), + Stacked([Exp{0}(), Log{0}()]), + Stacked([Log{0}(), Exp{0}()]), + Composed((Exp{0}(), Log{0}())), + Composed((Log{0}(), Exp{0}())), + Composed([Exp{0}(), Log{0}()]), + Composed([Log{0}(), Exp{0}()]), + TruncatedBijector(1.0, 2.0), + TruncatedBijector(1.0, 3.0), + TruncatedBijector(0.0, 2.0), + ] + for i in 1:length(bs), j in 1:length(bs) + if i == j + @test bs[i] == deepcopy(bs[j]) + @test inv(bs[i]) == inv(deepcopy(bs[j])) + else + @test bs[i] != bs[j] + end + end +end + include("norm_flows.jl") \ No newline at end of file