diff --git a/stheno/model/fdd.py b/stheno/model/fdd.py index 90d65f2..49c4f86 100644 --- a/stheno/model/fdd.py +++ b/stheno/model/fdd.py @@ -60,7 +60,12 @@ def __init__(self, p: PromisedGP, x, noise): self.p = p self.x = x self.noise = _noise_as_matrix(noise, B.dtype(x), infer_size(p.kernel, x)) - Normal.__init__(self, lambda: p.mean(x), lambda: p.kernel(x) + self.noise) + Normal.__init__( + self, + lambda: p.mean(x), + lambda: B.add(p.kernel(x), self.noise), + lambda: B.add(B.squeeze(p.kernel.elwise(x), axis=-1), B.diag(self.noise)), + ) @_dispatch def __init__(self, p: PromisedGP, x): diff --git a/stheno/random.py b/stheno/random.py index b2aeb5f..b1cc6cd 100644 --- a/stheno/random.py +++ b/stheno/random.py @@ -1,4 +1,3 @@ -import warnings from types import FunctionType from lab import B @@ -6,7 +5,7 @@ from plum import convert, Union from wbml.util import indented_kv -from . import _dispatch, BreakingChangeWarning +from . import _dispatch __all__ = ["Random", "RandomProcess", "RandomVector", "Normal"] @@ -61,22 +60,38 @@ def __init__( self._mean = mean self._mean_is_zero = None self._var = var + self._var_diag = None + self._construct_var_diag = None @_dispatch def __init__(self, var: Union[B.Numeric, AbstractMatrix]): Normal.__init__(self, 0, var) @_dispatch - def __init__(self, construct_mean: FunctionType, construct_var: FunctionType): + def __init__( + self, + mean: FunctionType, + var: FunctionType, + var_diag: Union[FunctionType, None] = None, + ): self._mean = None - self._construct_mean = construct_mean + self._construct_mean = mean self._mean_is_zero = None self._var = None - self._construct_var = construct_var + self._construct_var = var + self._var_diag = None + self._construct_var_diag = var_diag @_dispatch - def __init__(self, construct_var: FunctionType): - Normal.__init__(self, lambda: 0, construct_var) + def __init__( + self, + var: FunctionType, + # Require this one as a keyword argument to prevent ambiguity with the usual + # two-argument method. + *, + var_diag: Union[FunctionType, None] = None, + ): + Normal.__init__(self, lambda: 0, var, var_diag) def _resolve_mean(self, construct_zeros): if self._mean is None: @@ -92,6 +107,13 @@ def _resolve_var(self): # Ensure that the variance is a structured matrix for efficient operations. self._var = convert(self._var, AbstractMatrix) + def _resolve_var_diag(self): + if self._var_diag is None: + if self._construct_var_diag is not None: + self._var_diag = self._construct_var_diag() + else: + self._var_diag = B.diag(self.var) + def __str__(self): return ( f"