From d905fa7c51d44b35e8e6c519918cded2528a8f14 Mon Sep 17 00:00:00 2001 From: Wessel Bruinsma Date: Sat, 16 Apr 2022 17:01:23 +0100 Subject: [PATCH] Add public interface for adding new transforms --- stheno/model/measure.py | 20 ++++++++++++++++++++ tests/model/test_model.py | 22 ++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/stheno/model/measure.py b/stheno/model/measure.py index 06d6449..43bfb39 100644 --- a/stheno/model/measure.py +++ b/stheno/model/measure.py @@ -122,6 +122,26 @@ def _update(self, p, mean, kernel, left_rule, right_rule=None): return p + def add_gp(self, mean, kernel, left_rule, right_rule=None): + """Add a new GP to the graph with a given mean function and kernel. + + Args: + mean (:class:`mlkernels.Mean`): Mean function. + kernel (:class:`mlkernels.Kernel`): Kernel. + left_rule (function): Function that takes in another process `i` + and which return the covariance between the new process (left argument) + and process `i` (right argument). This function can make use of + means and kernels available in the property :attr:`.Measure.means` + and :attr:`.Measure.kernels`. + right_rule (function, optional): Like `left_rule`, but the other way around. + + Returns: + :class:`.gp.GP`: New GP. + """ + p = GP() + self._update(p, mean, kernel, left_rule, right_rule=None) + return p + @_dispatch def __call__(self, p: GP): # Make a new GP with `self` as the prior. diff --git a/tests/model/test_model.py b/tests/model/test_model.py index 318518d..91457e3 100644 --- a/tests/model/test_model.py +++ b/tests/model/test_model.py @@ -408,6 +408,28 @@ def test_logpdf(PseudoObs): approx(m.logpdf(obs), p3(x3, 1).logpdf(y3)) +def test_manual_new_gp(): + m = Measure() + p1 = GP(EQ(), measure=m) + p2 = GP(EQ(), measure=m) + p_sum = p1 + p2 + + p1_equivalent = m.add_gp( + m.means[p_sum] - m.means[p2], + ( + m.kernels[p_sum] + + m.kernels[p2] + - m.kernels[p_sum, p2] + - m.kernels[p2, p_sum] + ), + lambda j: m.kernels[p_sum, j] - m.kernels[p2, j], + ) + + x = B.linspace(0, 10, 5) + s1, s2 = m.sample(p1(x), p1_equivalent(x)) + approx(s1, s2, rtol=1e-5) + + def test_stretching(): # Test construction: p = GP(TensorProductMean(lambda x: x**2), EQ())