Skip to content

Commit

Permalink
Fix broken tests (#170)
Browse files Browse the repository at this point in the history
* Fix AbstractGPs dep warnings

* Bump patch
  • Loading branch information
willtebbutt authored Apr 2, 2021
1 parent 73c47ea commit b4e2d20
Show file tree
Hide file tree
Showing 18 changed files with 98 additions and 106 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Stheno"
uuid = "8188c328-b5d6-583d-959b-9690869a5511"
version = "0.7.0"
version = "0.7.1"

[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Expand All @@ -19,7 +19,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractGPs = "0.2.21"
AbstractGPs = "0.2.25"
BlockArrays = "0.15"
ChainRulesCore = "0.9"
Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24"
Expand Down
6 changes: 3 additions & 3 deletions docs/src/internals.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ As such, we have the following functions in addition to the AbstractGPs API impl

| Function | Brief description |
|:--------------------- |:---------------------- |
| `cov_diag(f, x)` | `diag(cov(f, x))` |
| `cov_diag(f, x, x′)` | `diag(cov(f, x, x′))` |
| `cov_diag(f, f′, x, x′)` | `diag(cov(f, f′, x, x′))` |
| `var(f, x)` | `diag(cov(f, x))` |
| `var(f, x, x′)` | `diag(cov(f, x, x′))` |
| `var(f, f′, x, x′)` | `diag(cov(f, f′, x, x′))` |

The second and third rows of the table only make sense when `length(x) == length(x′)`, of course.

Expand Down
9 changes: 6 additions & 3 deletions src/Stheno.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ module Stheno
import AbstractGPs:
mean,
cov,
cov_diag,
var,
mean_and_cov,
mean_and_cov_diag,
mean_and_var,
rand,
logpdf,
elbo,
dtc
dtc,
posterior,
approx_posterior

using MacroTools: @capture, combinedef, postwalk, splitdef

Expand Down Expand Up @@ -74,4 +76,5 @@ module Stheno
export wrap, BlockData, GPC, GPPPInput, @gppp
export elbo, dtc
export , select, stretch, periodic, shift
export cov_diag, mean_and_cov_diag
end # module
7 changes: 1 addition & 6 deletions src/abstract_gp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,10 @@
abstract type SthenoAbstractGP <: AbstractGP end

# TYPE-PIRACY
function AbstractGPs.cov_diag(f::GP, x::AbstractVector, x′::AbstractVector)
function var(f::GP, x::AbstractVector, x′::AbstractVector)
return kernelmatrix_diag(f.kernel, x, x′)
end

# Implement some of the AbstractGPs API for all of the GPs in this package.
mean_and_cov(f::SthenoAbstractGP, x::AbstractVector) = (mean(f, x), cov(f, x))

mean_and_cov_diag(f::SthenoAbstractGP, x::AbstractVector) = (mean(f, x), cov_diag(f, x))

# Ensure that this package gets to handle the covariance between its own GPs.
# AbstractGPs doesn't support this in general because it's unclear how it ought to be
# implemented, but we have a clear way to implement it here.
Expand Down
30 changes: 12 additions & 18 deletions src/composite/addition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,15 @@ mean((_, fa, fb)::add_args, x::AV) = mean(fa, x) .+ mean(fb, x)
function cov((_, fa, fb)::add_args, x::AV)
return cov(fa, x) .+ cov(fb, x) .+ cov(fa, fb, x, x) .+ cov(fb, fa, x, x)
end
function cov_diag((_, fa, fb)::add_args, x::AV)
return +(
cov_diag(fa, x), cov_diag(fb, x),
cov_diag(fa, fb, x, x), cov_diag(fb, fa, x, x),
)
function var((_, fa, fb)::add_args, x::AV)
return var(fa, x) .+ var(fb, x) .+ var(fa, fb, x, x) .+ var(fb, fa, x, x)
end

function cov((_, fa, fb)::add_args, x::AV, x′::AV)
return cov(fa, x, x′) .+ cov(fb, x, x′) .+ cov(fa, fb, x, x′) .+ cov(fb, fa, x, x′)
end
function cov_diag((_, fa, fb)::add_args, x::AV, x′::AV)
return +(
cov_diag(fa, x, x′), cov_diag(fb, x, x′),
cov_diag(fa, fb, x, x′), cov_diag(fb, fa, x, x′),
)
function var((_, fa, fb)::add_args, x::AV, x′::AV)
return var(fa, x, x′) .+ var(fb, x, x′) .+ var(fa, fb, x, x′) .+ var(fb, fa, x, x′)
end

function cov((_, fa, fb)::add_args, f′::AbstractGP, x::AV, x′::AV)
Expand All @@ -48,11 +42,11 @@ function cov(f::AbstractGP, (_, fa, fb)::add_args, x::AV, x′::AV)
return cov(f, fa, x, x′) .+ cov(f, fb, x, x′)
end

function cov_diag((_, fa, fb)::add_args, f′::AbstractGP, x::AV, x′::AV)
return cov_diag(fa, f′, x, x′) .+ cov_diag(fb, f′, x, x′)
function var((_, fa, fb)::add_args, f′::AbstractGP, x::AV, x′::AV)
return var(fa, f′, x, x′) .+ var(fb, f′, x, x′)
end
function cov_diag(f::AbstractGP, (_, fa, fb)::add_args, x::AV, x′::AV)
return cov_diag(f, fa, x, x′) .+ cov_diag(f, fb, x, x′)
function var(f::AbstractGP, (_, fa, fb)::add_args, x::AV, x′::AV)
return var(f, fa, x, x′) .+ var(f, fb, x, x′)
end


Expand All @@ -72,13 +66,13 @@ mean((_, b, f)::add_known, x::AV) = b.(x) .+ mean(f, x)
mean((_, b, f)::add_known{<:Real}, x::AV) = b .+ mean(f, x)

cov((_, b, f)::add_known, x::AV) = cov(f, x)
cov_diag((_, b, f)::add_known, x::AV) = cov_diag(f, x)
var((_, b, f)::add_known, x::AV) = var(f, x)

cov((_, b, f)::add_known, x::AV, x′::AV) = cov(f, x, x′)
cov_diag((_, b, f)::add_known, x::AV, x′::AV) = cov_diag(f, x, x′)
var((_, b, f)::add_known, x::AV, x′::AV) = var(f, x, x′)

cov((_, b, f)::add_known, f′::AbstractGP, x::AV, x′::AV) = cov(f, f′, x, x′)
cov(f::AbstractGP, (_, b, f′)::add_known, x::AV, x′::AV) = cov(f, f′, x, x′)

cov_diag((_, b, f)::add_known, f′::AbstractGP, x::AV, x′::AV) = cov_diag(f, f′, x, x′)
cov_diag(f::AbstractGP, (_, b, f′)::add_known, x::AV, x′::AV) = cov_diag(f, f′, x, x′)
var((_, b, f)::add_known, f′::AbstractGP, x::AV, x′::AV) = var(f, f′, x, x′)
var(f::AbstractGP, (_, b, f′)::add_known, x::AV, x′::AV) = var(f, f′, x, x′)
8 changes: 4 additions & 4 deletions src/composite/compose.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,16 @@ const comp_args = Tuple{typeof(∘), AbstractGP, Any}
mean((_, f, g)::comp_args, x::AV) = mean(f, g.(x))

cov((_, f, g)::comp_args, x::AV) = cov(f, g.(x))
cov_diag((_, f, g)::comp_args, x::AV) = cov_diag(f, g.(x))
var((_, f, g)::comp_args, x::AV) = var(f, g.(x))

cov((_, f, g)::comp_args, x::AV, x′::AV) = cov(f, g.(x), g.(x′))
cov_diag((_, f, g)::comp_args, x::AV, x′::AV) = cov_diag(f, g.(x), g.(x′))
var((_, f, g)::comp_args, x::AV, x′::AV) = var(f, g.(x), g.(x′))

cov((_, f, g)::comp_args, f′::AbstractGP, x::AV, x′::AV) = cov(f, f′, g.(x), x′)
cov(f::AbstractGP, (_, f′, g)::comp_args, x::AV, x′::AV) = cov(f, f′, x, g.(x′))

cov_diag((_, f, g)::comp_args, f′::AbstractGP, x::AV, x′::AV) = cov_diag(f, f′, g.(x), x′)
cov_diag(f::AbstractGP, (_, f′, g)::comp_args, x::AV, x′::AV) = cov_diag(f, f′, x, g.(x′))
var((_, f, g)::comp_args, f′::AbstractGP, x::AV, x′::AV) = var(f, f′, g.(x), x′)
var(f::AbstractGP, (_, f′, g)::comp_args, x::AV, x′::AV) = var(f, f′, x, g.(x′))


"""
Expand Down
12 changes: 6 additions & 6 deletions src/composite/composite_gp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ CompositeGP(args::Targs, gpc::GPC) where {Targs} = CompositeGP{Targs}(args, gpc)
mean(f::CompositeGP, x::AbstractVector) = mean(f.args, x)

cov(f::CompositeGP, x::AbstractVector) = cov(f.args, x)
cov_diag(f::CompositeGP, x::AbstractVector) = cov_diag(f.args, x)
var(f::CompositeGP, x::AbstractVector) = var(f.args, x)

cov(f::CompositeGP, x::AbstractVector, x′::AbstractVector) = cov(f.args, x, x′)
cov_diag(f::CompositeGP, x::AbstractVector, x′::AbstractVector) = cov_diag(f.args, x, x′)
var(f::CompositeGP, x::AbstractVector, x′::AbstractVector) = var(f.args, x, x′)

function cov(
f::SthenoAbstractGP, f′::SthenoAbstractGP, x::AbstractVector, x′::AbstractVector,
Expand All @@ -39,17 +39,17 @@ function cov(
end
end

function cov_diag(
function var(
f::SthenoAbstractGP, f′::SthenoAbstractGP, x::AbstractVector, x′::AbstractVector,
)
@assert f.gpc === f′.gpc
if f.n === f′.n
return cov_diag(f.args, x, x′)
return var(f.args, x, x′)
elseif f isa WrappedGP && f.n > f′.n || f′ isa WrappedGP && f′.n > f.n
return zeros(length(x))
elseif f.n >= f′.n
return cov_diag(f.args, f′, x, x′)
return var(f.args, f′, x, x′)
else
return cov_diag(f, f′.args, x, x′)
return var(f, f′.args, x, x′)
end
end
12 changes: 6 additions & 6 deletions src/composite/cross.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ function cov((_, fs)::cross_args, x::BlockData)
return Array(mortar(reshape(Cs, :, 1)))
end

function cov_diag((_, fs)::cross_args, x::BlockData)
cs = map(cov_diag, fs, blocks(x))
function var((_, fs)::cross_args, x::BlockData)
cs = map(var, fs, blocks(x))
return Array(mortar(cs))
end

Expand All @@ -37,8 +37,8 @@ function cov((_, fs)::cross_args, x::BlockData, x′::BlockData)
return Array(mortar(reshape(Cs, :, 1)))
end

function cov_diag((_, fs)::cross_args, x::BlockData, x′::BlockData)
cs = map(cov_diag, fs, blocks(x), blocks(x′))
function var((_, fs)::cross_args, x::BlockData, x′::BlockData)
cs = map(var, fs, blocks(x), blocks(x′))
return Array(mortar(cs))
end

Expand All @@ -51,10 +51,10 @@ function cov(f::AbstractGP, (_, fs)::cross_args, x::AV, x′::BlockData)
return Array(mortar(Cs))
end

function cov_diag(args::cross_args, f′::AbstractGP, x::BlockData, x′::AV)
function var(args::cross_args, f′::AbstractGP, x::BlockData, x′::AV)
return diag(cov(args, f′, x, x′))
end
function cov_diag(f::AbstractGP, args::cross_args, x::AV, x′::BlockData)
function var(f::AbstractGP, args::cross_args, x::AV, x′::BlockData)
return diag(cov(f, args, x, x′))
end

Expand Down
26 changes: 13 additions & 13 deletions src/composite/product.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,23 @@ function cov((_, σ, g)::prod_args, x::AV)
σx = σ.(x)
return σx .* cov(g, x) .* σx'
end
cov_diag((_, σ, g)::prod_args, x::AV) = σ.(x).^2 .* cov_diag(g, x)
var((_, σ, g)::prod_args, x::AV) = σ.(x).^2 .* var(g, x)

function cov((_, σ, g)::prod_args, x::AV, x′::AV)
return σ.(x) .* cov(g, x, x′) .* σ.(x′)'
end
function cov_diag((_, σ, g)::prod_args, x::AV, x′::AV)
return σ.(x) .* cov_diag(g, x, x′) .* σ.(x′)
function var((_, σ, g)::prod_args, x::AV, x′::AV)
return σ.(x) .* var(g, x, x′) .* σ.(x′)
end

cov((_, σ, f)::prod_args, f′::AbstractGP, x::AV, x′::AV) = σ.(x) .* cov(f, f′, x, x′)
cov(f::AbstractGP, (_, σ, f′)::prod_args, x::AV, x′::AV) = cov(f, f′, x, x′) .* (σ.(x′))'

function cov_diag((_, σ, f)::prod_args, f′::AbstractGP, x::AV, x′::AV)
return σ.(x) .* cov_diag(f, f′, x, x′)
function var((_, σ, f)::prod_args, f′::AbstractGP, x::AV, x′::AV)
return σ.(x) .* var(f, f′, x, x′)
end
function cov_diag(f::AbstractGP, (_, σ, f′)::prod_args, x::AV, x′::AV)
return cov_diag(f, f′, x, x′) .* σ.(x′)
function var(f::AbstractGP, (_, σ, f′)::prod_args, x::AV, x′::AV)
return var(f, f′, x, x′) .* σ.(x′)
end

#
Expand All @@ -50,19 +50,19 @@ end
mean((_, σ, g)::prod_args{<:Real}, x::AV) = σ .* mean(g, x)

cov((_, σ, g)::prod_args{<:Real}, x::AV) =^2) .* cov(g, x)
cov_diag((_, σ, g)::prod_args{<:Real}, x::AV) =^2) .* cov_diag(g, x)
var((_, σ, g)::prod_args{<:Real}, x::AV) =^2) .* var(g, x)

cov((_, σ, g)::prod_args{<:Real}, x::AV, x′::AV) =^2) .* cov(g, x, x′)
cov_diag((_, σ, g)::prod_args{<:Real}, x::AV, x′::AV) =^2) .* cov_diag(g, x, x′)
var((_, σ, g)::prod_args{<:Real}, x::AV, x′::AV) =^2) .* var(g, x, x′)

cov((_, σ, f)::prod_args{<:Real}, f′::AbstractGP, x::AV, x′::AV) = σ .* cov(f, f′, x, x′)
cov(f::AbstractGP, (_, σ, f′)::prod_args{<:Real}, x::AV, x′::AV) = cov(f, f′, x, x′) .* σ

function cov_diag((_, σ, f)::prod_args{<:Real}, f′::AbstractGP, x::AV, x′::AV)
return σ .* cov_diag(f, f′, x, x′)
function var((_, σ, f)::prod_args{<:Real}, f′::AbstractGP, x::AV, x′::AV)
return σ .* var(f, f′, x, x′)
end
function cov_diag(f::AbstractGP, (_, σ, f′)::prod_args{<:Real}, x::AV, x′::AV)
return cov_diag(f, f′, x, x′) .* σ
function var(f::AbstractGP, (_, σ, f′)::prod_args{<:Real}, x::AV, x′::AV)
return var(f, f′, x, x′) .* σ
end

# Use multiplication to define the negation of a GP
Expand Down
20 changes: 10 additions & 10 deletions src/gaussian_process_probabilistic_programme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,41 +59,41 @@ function extract_components(f::GPPP, x::BlockData)
return cross(fs), BlockData(vs)
end

function AbstractGPs.mean(f::GPPP, x::AbstractVector)
function mean(f::GPPP, x::AbstractVector)
fs, vs = extract_components(f, x)
return mean(fs, vs)
end

function AbstractGPs.cov(f::GPPP, x::AbstractVector)
function cov(f::GPPP, x::AbstractVector)
fs, vs = extract_components(f, x)
return cov(fs, vs)
end

function AbstractGPs.cov_diag(f::GPPP, x::AbstractVector)
function var(f::GPPP, x::AbstractVector)
fs, vs = extract_components(f, x)
return cov_diag(fs, vs)
return var(fs, vs)
end

function AbstractGPs.cov(f::GPPP, x::AbstractVector, x′::AbstractVector)
function cov(f::GPPP, x::AbstractVector, x′::AbstractVector)
fs_x, vs_x = extract_components(f, x)
fs_x′, vs_x′ = extract_components(f, x′)
return cov(fs_x, fs_x′, vs_x, vs_x′)
end

function AbstractGPs.cov_diag(f::GPPP, x::AbstractVector, x′::AbstractVector)
function var(f::GPPP, x::AbstractVector, x′::AbstractVector)
fs_x, vs_x = extract_components(f, x)
fs_x′, vs_x′ = extract_components(f, x′)
return cov_diag(fs_x, fs_x′, vs_x, vs_x′)
return var(fs_x, fs_x′, vs_x, vs_x′)
end

function AbstractGPs.mean_and_cov(f::GPPP, x::AbstractVector)
function mean_and_cov(f::GPPP, x::AbstractVector)
fs, vs = extract_components(f, x)
return mean_and_cov(fs, vs)
end

function AbstractGPs.mean_and_cov_diag(f::GPPP, x::AbstractVector)
function mean_and_var(f::GPPP, x::AbstractVector)
fs, vs = extract_components(f, x)
return mean_and_cov_diag(fs, vs)
return mean_and_var(fs, vs)
end


Expand Down
8 changes: 4 additions & 4 deletions src/gp/gp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ wrap(gp::Tgp, gpc::GPC) where {Tgp<:GP} = WrappedGP{Tgp}(gp, gpc)
mean(f::WrappedGP, x::AbstractVector) = mean(f.gp, x)

cov(f::WrappedGP, x::AbstractVector) = cov(f.gp, x)
cov_diag(f::WrappedGP, x::AbstractVector) = cov_diag(f.gp, x)
var(f::WrappedGP, x::AbstractVector) = var(f.gp, x)

cov(f::WrappedGP, x::AbstractVector, x′::AbstractVector) = cov(f.gp, x, x′)
cov_diag(f::WrappedGP, x::AbstractVector, x′::AbstractVector) = cov_diag(f.gp, x, x′)
var(f::WrappedGP, x::AbstractVector, x′::AbstractVector) = var(f.gp, x, x′)

function cov(f::WrappedGP, f′::WrappedGP, x::AbstractVector, x′::AbstractVector)
return f === f′ ? cov(f, x, x′) : zeros(length(x), length(x′))
end
function cov_diag(f::WrappedGP, f′::WrappedGP, x::AbstractVector, x′::AbstractVector)
return f === f′ ? cov_diag(f, x, x′) : zeros(length(x))
function var(f::WrappedGP, f′::WrappedGP, x::AbstractVector, x′::AbstractVector)
return f === f′ ? var(f, x, x′) : zeros(length(x))
end
24 changes: 12 additions & 12 deletions src/sparse_finite_gp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,29 +40,29 @@ end

Base.length(f::SparseFiniteGP) = length(f.fobs)

AbstractGPs.mean(f::SparseFiniteGP) = mean(f.fobs)
mean(f::SparseFiniteGP) = mean(f.fobs)

const __covariance_error = "The covariance matrix of a sparse GP can often be dense and " *
"can cause the computer to run out of memory. If you are sure you have enough " *
"memory, you can use `cov(f.fobs)`."

AbstractGPs.cov(f::SparseFiniteGP) = error(__covariance_error)
cov(f::SparseFiniteGP) = error(__covariance_error)

AbstractGPs.marginals(f::SparseFiniteGP) = marginals(f.fobs)
marginals(f::SparseFiniteGP) = marginals(f.fobs)

AbstractGPs.rand(rng::AbstractRNG, f::SparseFiniteGP, N::Int) = rand(rng, f.fobs, N)
AbstractGPs.rand(f::SparseFiniteGP, N::Int) = rand(Random.GLOBAL_RNG, f, N)
AbstractGPs.rand(rng::AbstractRNG, f::SparseFiniteGP) = vec(rand(rng, f, 1))
AbstractGPs.rand(f::SparseFiniteGP) = vec(rand(f, 1))
rand(rng::AbstractRNG, f::SparseFiniteGP, N::Int) = rand(rng, f.fobs, N)
rand(f::SparseFiniteGP, N::Int) = rand(Random.GLOBAL_RNG, f, N)
rand(rng::AbstractRNG, f::SparseFiniteGP) = vec(rand(rng, f, 1))
rand(f::SparseFiniteGP) = vec(rand(f, 1))

AbstractGPs.elbo(f::SparseFiniteGP, y::AV{<:Real}) = elbo(f.fobs, y, f.finducing)
elbo(f::SparseFiniteGP, y::AV{<:Real}) = elbo(f.fobs, y, f.finducing)

AbstractGPs.logpdf(f::SparseFiniteGP, y::AV{<:Real}) = elbo(f.fobs, y, f.finducing)
logpdf(f::SparseFiniteGP, y::AV{<:Real}) = elbo(f.fobs, y, f.finducing)

function AbstractGPs.logpdf(f::SparseFiniteGP, Y::AbstractMatrix{<:Real})
function logpdf(f::SparseFiniteGP, Y::AbstractMatrix{<:Real})
return map(y -> logpdf(f, y), eachcol(Y))
end

function AbstractGPs.posterior(f::SparseFiniteGP, y::AbstractVector{<:Real})
return AbstractGPs.approx_posterior(AbstractGPs.VFE(), f.fobs, y, f.finducing)
function posterior(f::SparseFiniteGP, y::AbstractVector{<:Real})
return approx_posterior(AbstractGPs.VFE(), f.fobs, y, f.finducing)
end
Loading

4 comments on commit b4e2d20

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: Changing package repo URL not allowed, please submit a pull request with the URL change to the target registry and retry.

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/33452

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.1 -m "<description of version>" b4e2d20f973a0816272fdf07bdd5896a614b99e1
git push origin v0.7.1

Please sign in to comment.