Skip to content

Commit

Permalink
Add mean_var
Browse files Browse the repository at this point in the history
  • Loading branch information
wesselb committed Apr 16, 2022
1 parent 0c29a3e commit 67de930
Show file tree
Hide file tree
Showing 11 changed files with 221 additions and 73 deletions.
21 changes: 13 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -375,13 +375,18 @@ f(x, noise) # Additional noise with variance `noise`

Things you can do with a finite-dimensional distribution:

*
*
Use `f(x).mean` to compute the mean.

*
*
Use `f(x).var` to compute the variance.

*
Use `f(x).mean_var` to simultaneously compute the mean and variance.
This can be substantially more efficient than first calling `f(x).mean` and then
`f(x).var`.

*
*
Use `Normal.sample` to sample.

Definition:
Expand All @@ -396,10 +401,10 @@ Things you can do with a finite-dimensional distribution:
f(x).sample(n, noise=noise) # Produce `n` samples with additional noise variance `noise`.
```

*
*
Use `f(x).logpdf(y)` to compute the logpdf of some data `y`.

*
*
Use `means, variances = f(x).marginals()` to efficiently compute the marginal means
and marginal variances.

Expand All @@ -410,7 +415,7 @@ Things you can do with a finite-dimensional distribution:
(array([0., 0., 0.]), np.array([1., 1., 1.]))
```

*
*
Use `means, lowers, uppers = f(x).marginal_credible_bounds()` to efficiently compute
the means and the marginal lower and upper 95% central credible region bounds.

Expand All @@ -421,7 +426,7 @@ Things you can do with a finite-dimensional distribution:
(array([0., 0., 0.]), array([-1.96, -1.96, -1.96]), array([1.96, 1.96, 1.96]))
```

*
*
Use `Measure.logpdf` to compute the joint logpdf of multiple observations.

Definition, where `prior = Measure()`:
Expand All @@ -432,7 +437,7 @@ Things you can do with a finite-dimensional distribution:
prior.logpdf((f1(x1), y1), (f2(x2), y2), ...)
```

*
*
Use `Measure.sample` to jointly sample multiple observations.

Definition, where `prior = Measure()`:
Expand Down
21 changes: 13 additions & 8 deletions README_without_examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -363,13 +363,18 @@ f(x, noise) # Additional noise with variance `noise`

Things you can do with a finite-dimensional distribution:

*
*
Use `f(x).mean` to compute the mean.

*
*
Use `f(x).var` to compute the variance.

*
Use `f(x).mean_var` to simultaneously compute the mean and variance.
This can be substantially more efficient than first calling `f(x).mean` and then
`f(x).var`.

*
*
Use `Normal.sample` to sample.

Definition:
Expand All @@ -384,10 +389,10 @@ Things you can do with a finite-dimensional distribution:
f(x).sample(n, noise=noise) # Produce `n` samples with additional noise variance `noise`.
```

*
*
Use `f(x).logpdf(y)` to compute the logpdf of some data `y`.

*
*
Use `means, variances = f(x).marginals()` to efficiently compute the marginal means
and marginal variances.

Expand All @@ -398,7 +403,7 @@ Things you can do with a finite-dimensional distribution:
(array([0., 0., 0.]), np.array([1., 1., 1.]))
```

*
*
Use `means, lowers, uppers = f(x).marginal_credible_bounds()` to efficiently compute
the means and the marginal lower and upper 95% central credible region bounds.

Expand All @@ -409,7 +414,7 @@ Things you can do with a finite-dimensional distribution:
(array([0., 0., 0.]), array([-1.96, -1.96, -1.96]), array([1.96, 1.96, 1.96]))
```

*
*
Use `Measure.logpdf` to compute the joint logpdf of multiple observations.

Definition, where `prior = Measure()`:
Expand All @@ -420,7 +425,7 @@ Things you can do with a finite-dimensional distribution:
prior.logpdf((f1(x1), y1), (f2(x2), y2), ...)
```

*
*
Use `Measure.sample` to jointly sample multiple observations.

Definition, where `prior = Measure()`:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"plum-dispatch>=1.5.3",
"backends>=1.4.11",
"backends-matrix>=1.2.11",
"mlkernels>=0.3.4",
"mlkernels>=0.3.6",
"wbml>=0.3.3",
]

Expand Down
17 changes: 16 additions & 1 deletion stheno/model/fdd.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import mlkernels
from lab import B
from matrix import AbstractMatrix, Dense, Zero, Diagonal
from mlkernels import Kernel, num_elements
Expand Down Expand Up @@ -60,11 +61,25 @@ def __init__(self, p: PromisedGP, x, noise):
self.p = p
self.x = x
self.noise = _noise_as_matrix(noise, B.dtype(x), infer_size(p.kernel, x))

def var_diag():
return B.add(B.squeeze(p.kernel.elwise(x), axis=-1), B.diag(self.noise))

def mean_var():
mean, var = mlkernels.mean_var(p.mean, p.kernel, x)
return mean, B.add(var, self.noise)

def mean_var_diag():
mean, var_diag = mlkernels.mean_var_diag(p.mean, p.kernel, x)
return mean, B.add(B.squeeze(var_diag, axis=-1), B.diag(self.noise))

Normal.__init__(
self,
lambda: p.mean(x),
lambda: B.add(p.kernel(x), self.noise),
lambda: B.add(B.squeeze(p.kernel.elwise(x), axis=-1), B.diag(self.noise)),
var_diag=var_diag,
mean_var=mean_var,
mean_var_diag=mean_var_diag,
)

@_dispatch
Expand Down
1 change: 0 additions & 1 deletion stheno/model/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ def mul(self, p_mul: GP, other, p: GP):
p_mul (:class:`.gp.GP`): GP that is the product.
obj1 (object): First factor in the product.
obj2 (object): Second factor in the product.
other (object): Other object in the product.
Returns:
:class:`.gp.GP`: The GP corresponding to the product.
Expand Down
2 changes: 1 addition & 1 deletion stheno/model/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def _compute(self, measure):
# Optimal mean:
y_bar = B.subtract(B.uprank(self.y), measure.means[p_x](x))
prod_y_bar = B.iqf(K_n, B.transpose(iLz_Kzx), y_bar)
# TODO: Absorb `L_z` in the posterior mean for better stability.
# TODO: Absorb `L_z` in the posterior mean for better stability?
mu = B.add(measure.means[p_z](z), B.iqf(A, B.transpose(L_z), prod_y_bar))
self._mu_store[id(measure)] = mu

Expand Down
84 changes: 61 additions & 23 deletions stheno/random.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from types import FunctionType

from lab import B
from matrix import AbstractMatrix, Zero
from matrix import AbstractMatrix, Zero, Diagonal
from plum import convert, Union
from wbml.util import indented_kv

Expand Down Expand Up @@ -72,7 +72,10 @@ def __init__(
self,
mean: FunctionType,
var: FunctionType,
*,
var_diag: Union[FunctionType, None] = None,
mean_var: Union[FunctionType, None] = None,
mean_var_diag: Union[FunctionType, None] = None,
):
self._mean = None
self._construct_mean = mean
Expand All @@ -81,17 +84,12 @@ def __init__(
self._construct_var = var
self._var_diag = None
self._construct_var_diag = var_diag
self._construct_mean_var = mean_var
self._construct_mean_var_diag = mean_var_diag

@_dispatch
def __init__(
self,
var: FunctionType,
# Require this one as a keyword argument to prevent ambiguity with the usual
# two-argument method.
*,
var_diag: Union[FunctionType, None] = None,
):
Normal.__init__(self, lambda: 0, var, var_diag)
def __init__(self, var: FunctionType, **kw_args):
Normal.__init__(self, lambda: 0, var, **kw_args)

def _resolve_mean(self, construct_zeros):
if self._mean is None:
Expand Down Expand Up @@ -146,41 +144,57 @@ def __repr__(self):

@property
def mean(self):
"""Mean."""
"""column vector: Mean."""
self._resolve_mean(construct_zeros=True)
return self._mean

@property
def mean_is_zero(self):
"""The mean is zero."""
"""bool: The mean is zero."""
self._resolve_mean(construct_zeros=False)
return self._mean_is_zero

@property
def var(self):
"""Variance."""
"""matrix: Variance."""
self._resolve_var()
return self._var

@property
def var_diag(self):
"""Diagonal of the variance."""
"""vector: Diagonal of the variance."""
self._resolve_var_diag()
return self._var_diag

    @property
    def mean_var(self):
        """tuple[column vector, matrix]: Mean and variance.

        Prefers already-computed cached values (`self._mean`, `self._var`).
        Only when neither is cached does it invoke the joint constructor
        `self._construct_mean_var`, which can compute both quantities at once
        — presumably more cheaply than constructing them separately (TODO
        confirm against the `_construct_mean_var` callables passed in).
        """
        # Both cached: return them directly without any construction work.
        if self._mean is not None and self._var is not None:
            return self._mean, self._var
        # Only the mean is cached: construct just the variance via the
        # `var` property (which resolves and caches it).
        elif self._mean is not None:
            return self._mean, self.var
        # Only the variance is cached: construct just the mean.
        elif self._var is not None:
            return self.mean, self._var
        else:
            # Nothing cached. If a joint constructor was provided, use it to
            # populate both caches in one call.
            if self._construct_mean_var is not None:
                self._mean, self._var = self._construct_mean_var()
            # Resolve normally; these are no-ops for values already set above,
            # and otherwise fall back to the separate constructors.
            self._resolve_mean(construct_zeros=True)
            self._resolve_var()
            return self.mean, self.var

@property
def dtype(self):
"""Data type of the variance."""
"""dtype: Data type of the variance."""
return B.dtype(self.var)

@property
def dim(self):
"""Dimensionality."""
"""int: Dimensionality."""
return B.shape_matrix(self.var)[0]

@property
def m2(self):
"""Second moment."""
"""matrix: Second moment."""
return self.var + B.outer(B.squeeze(self.mean))

def marginals(self):
Expand All @@ -189,7 +203,17 @@ def marginals(self):
Returns:
tuple: A tuple containing the marginal means and marginal variances.
"""
mean, var_diag = self.mean, self.var_diag
if self._mean is not None and self._var_diag is not None:
mean, var_diag = self._mean, self._var_diag
elif self._mean is not None:
mean, var_diag = self._mean, self.var_diag
elif self._var_diag is not None:
mean, var_diag = self.mean, self._var_diag
else:
if self._construct_mean_var_diag is not None:
self._mean, self._var_diag = self._construct_mean_var_diag()
self._resolve_mean(construct_zeros=True)
mean, var_diag = self.mean, self.var_diag
# It can happen that the variances are slightly negative due to numerical noise.
# Prevent NaNs from the following square root by taking the maximum with zero.
# Also strip away any matrix structure.
Expand All @@ -209,6 +233,14 @@ def marginal_credible_bounds(self):
error = 1.96 * B.sqrt(var)
return mean, mean - error, mean + error

def diagonalise(self):
"""Diagonalise the normal distribution by setting the correlations to zero.
Returns:
:class:`.Normal`: Diagonal version of the distribution.
"""
return Normal(self.mean, Diagonal(self.var_diag))

def logpdf(self, x):
"""Compute the log-pdf.
Expand Down Expand Up @@ -328,26 +360,32 @@ def sample(self, num: B.Int = 1, noise=None):

@_dispatch
def __add__(self, other: B.Numeric):
return Normal(self.mean + other, self.var)
return Normal(B.add(self.mean, other), self.var)

@_dispatch
def __add__(self, other: "Normal"):
return Normal(B.add(self.mean, other.mean), B.add(self.var, other.var))
return Normal(
B.add(self.mean, other.mean),
B.add(self.var, other.var),
)

@_dispatch
def __mul__(self, other: B.Numeric):
return Normal(B.multiply(self.mean, other), B.multiply(self.var, other**2))
return Normal(
B.multiply(self.mean, other),
B.multiply(self.var, B.multiply(other, other)),
)

def lmatmul(self, other):
return Normal(
B.matmul(other, self.mean),
B.matmul(B.matmul(other, self.var), other, tr_b=True),
B.matmul(other, self.var, other, tr_c=True),
)

def rmatmul(self, other):
return Normal(
B.matmul(other, self.mean, tr_a=True),
B.matmul(B.matmul(other, self.var, tr_a=True), other),
B.matmul(other, self.var, other, tr_a=True),
)


Expand Down
17 changes: 14 additions & 3 deletions tests/model/test_fdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_fdd_take():
B.take(fdd, np.array([1, 2]))


def test_fdd_diag():
def test_fdd_properties():
p = GP(EQ())

# Sample observations.
Expand All @@ -117,6 +117,17 @@ def test_fdd_diag():
# Compute posterior.
p = p | (p(x, 0.1), y)

# Check that the diagonal is computed correctly.
fdd = p(B.linspace(0, 5, 10), 0.2)
approx(fdd.var_diag, B.diag(fdd.var))
mean, var = fdd.mean, fdd.var

# Check `var_diag`.
fdd = p(B.linspace(0, 5, 10), 0.2)
approx(fdd.var_diag, B.diag(var))

# Check `mean_var`.
fdd = p(B.linspace(0, 5, 10), 0.2)
approx(fdd.mean_var, (mean, var))

# Check `marginals()`.
fdd = p(B.linspace(0, 5, 10), 0.2)
approx(fdd.marginals(), (B.flatten(mean), B.diag(var)))
Loading

0 comments on commit 67de930

Please sign in to comment.