Skip to content

Commit

Permalink
Utilise AbstractGPs.TestUtils (#206)
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Jul 22, 2021
1 parent 37b4697 commit 7d173b1
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 38 deletions.
36 changes: 2 additions & 34 deletions test/gaussian_process_probabilistic_programme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,40 +82,8 @@
GPPPInput(:f1, randn(4)),
),
]

atol=1e-9
rtol=1e-9

m = mean(f, x0)
@test m isa AbstractVector{<:Real}
@test length(m) == length(x0)

@assert length(x0) length(x1)

# Check that unary cov is consistent with binary cov and conforms to the API
K_x0 = cov(f, x0)
@test K_x0 isa AbstractMatrix{<:Real}
@test size(K_x0) == (length(x0), length(x0))
@test K_x0 cov(f, x0, x0) atol=atol rtol=rtol
@test minimum(eigvals(K_x0)) > -atol
@test K_x0 K_x0' atol=atol rtol=rtol

# Check that single-process binary cov is consistent with single-process binary-cov
K_x0_x1 = cov(f, x0, x1)
@test K_x0_x1 isa AbstractMatrix{<:Real}
@test size(K_x0_x1) == (length(x0), length(x1))

# Check that single-process binary var is consistent.
K_x0_x0_diag = var(f, x0, x0)
@test K_x0_x0_diag isa AbstractVector{<:Real}
@test length(K_x0_x0_diag) == length(x0)
@test K_x0_x0_diag diag(cov(f, x0, x0)) atol=atol rtol=rtol

# Check that unary var conforms to the API and is consistent with unary cov
K_x0_diag = var(f, x0)
@test K_x0_diag isa AbstractVector{<:Real}
@test length(K_x0_diag) == length(x0)
@test K_x0_diag diag(cov(f, x0)) atol=atol rtol=rtol
rng = MersenneTwister(123456)
AbstractGPs.TestUtils.test_internal_abstractgps_interface(rng, f, x0, x1)
end

@timedtestset "gppp macro" begin
Expand Down
5 changes: 1 addition & 4 deletions test/gp/gp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@

@test mean(f, x) == AbstractGPs._map(m, x)
@test cov(f, x) == kernelmatrix(k, x)
@test var(f, x) == diag(cov(f, x))
@test cov(f, x, x) == kernelmatrix(k, x, x)
@test cov(f, x, x′) == kernelmatrix(k, x, x′)
@test cov(f, x, x′) cov(f, x′, x)'
AbstractGPs.TestUtils.test_internal_abstractgps_interface(rng, f, x, x′)
end

# Test the creation of indepenent GPs.
Expand Down

0 comments on commit 7d173b1

Please sign in to comment.