Skip to content

Commit

Permalink
Add public interface for adding new transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Apr 16, 2022
1 parent e6379d0 commit d905fa7
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
20 changes: 20 additions & 0 deletions stheno/model/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,26 @@ def _update(self, p, mean, kernel, left_rule, right_rule=None):

return p

def add_gp(self, mean, kernel, left_rule, right_rule=None):
"""Add a new GP to the graph with a given mean function and kernel.
Args:
mean (:class:`mlkernels.Mean`): Mean function.
kernel (:class:`mlkernels.Kernel`): Kernel.
left_rule (function): Function that takes in another process `i`
and which return the covariance between the new process (left argument)
and process `i` (right argument). This function can make use of
means and kernels available in the property :attr:`.Measure.means`
and :attr:`.Measure.kernels`.
right_rule (function, optional): Like `left_rule`, but the other way around.
Returns:
:class:`.gp.GP`: New GP.
"""
p = GP()
self._update(p, mean, kernel, left_rule, right_rule=None)
return p

@_dispatch
def __call__(self, p: GP):
# Make a new GP with `self` as the prior.
Expand Down
22 changes: 22 additions & 0 deletions tests/model/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,28 @@ def test_logpdf(PseudoObs):
approx(m.logpdf(obs), p3(x3, 1).logpdf(y3))


def test_manual_new_gp():
m = Measure()
p1 = GP(EQ(), measure=m)
p2 = GP(EQ(), measure=m)
p_sum = p1 + p2

p1_equivalent = m.add_gp(
m.means[p_sum] - m.means[p2],
(
m.kernels[p_sum]
+ m.kernels[p2]
- m.kernels[p_sum, p2]
- m.kernels[p2, p_sum]
),
lambda j: m.kernels[p_sum, j] - m.kernels[p2, j],
)

x = B.linspace(0, 10, 5)
s1, s2 = m.sample(p1(x), p1_equivalent(x))
approx(s1, s2, rtol=1e-5)


def test_stretching():
# Test construction:
p = GP(TensorProductMean(lambda x: x**2), EQ())
Expand Down

0 comments on commit d905fa7

Please sign in to comment.