Improve efficiency of marginals
wesselb committed Apr 1, 2022
1 parent a0188e3 commit 574a3c7
Showing 6 changed files with 99 additions and 16 deletions.
7 changes: 6 additions & 1 deletion stheno/model/fdd.py
@@ -60,7 +60,12 @@ def __init__(self, p: PromisedGP, x, noise):
self.p = p
self.x = x
self.noise = _noise_as_matrix(noise, B.dtype(x), infer_size(p.kernel, x))
Normal.__init__(self, lambda: p.mean(x), lambda: p.kernel(x) + self.noise)
Normal.__init__(
self,
lambda: p.mean(x),
lambda: B.add(p.kernel(x), self.noise),
lambda: B.add(B.squeeze(p.kernel.elwise(x), axis=-1), B.diag(self.noise)),
)

@_dispatch
def __init__(self, p: PromisedGP, x):
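The third constructor lets an FDD produce its marginal variances from `p.kernel.elwise(x)` plus the noise diagonal, without forming the full kernel matrix. A minimal sketch of the intended effect, assuming `stheno`, `lab`, and `mlkernels` are installed:

from lab import B
from mlkernels import EQ
from stheno import GP

p = GP(EQ())
x = B.linspace(0, 5, 10_000)

fdd = p(x, 0.1)              # FDD with observation noise 0.1
mean, var = fdd.marginals()  # uses the elementwise kernel; no 10_000 x 10_000 matrix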
56 changes: 42 additions & 14 deletions stheno/random.py
@@ -1,12 +1,11 @@
import warnings
from types import FunctionType

from lab import B
from matrix import AbstractMatrix, Zero
from plum import convert, Union
from wbml.util import indented_kv

from . import _dispatch, BreakingChangeWarning
from . import _dispatch

__all__ = ["Random", "RandomProcess", "RandomVector", "Normal"]

@@ -61,22 +60,38 @@ def __init__(
self._mean = mean
self._mean_is_zero = None
self._var = var
self._var_diag = None
self._construct_var_diag = None

@_dispatch
def __init__(self, var: Union[B.Numeric, AbstractMatrix]):
Normal.__init__(self, 0, var)

@_dispatch
def __init__(self, construct_mean: FunctionType, construct_var: FunctionType):
def __init__(
self,
mean: FunctionType,
var: FunctionType,
var_diag: Union[FunctionType, None] = None,
):
self._mean = None
self._construct_mean = construct_mean
self._construct_mean = mean
self._mean_is_zero = None
self._var = None
self._construct_var = construct_var
self._construct_var = var
self._var_diag = None
self._construct_var_diag = var_diag

@_dispatch
def __init__(self, construct_var: FunctionType):
Normal.__init__(self, lambda: 0, construct_var)
def __init__(
self,
var: FunctionType,
# Require this one as a keyword argument to prevent ambiguity with the usual
# two-argument method.
*,
var_diag: Union[FunctionType, None] = None,
):
Normal.__init__(self, lambda: 0, var, var_diag)

def _resolve_mean(self, construct_zeros):
if self._mean is None:
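The constructors stay fully lazy: nothing is evaluated until a property is accessed. A small sketch of the new three-thunk form, assuming `Normal` is imported from `stheno` as in the tests:

from lab import B
from stheno import Normal

dist = Normal(
    lambda: B.zeros(3, 1),  # mean
    lambda: B.eye(3),       # full variance
    lambda: B.ones(3),      # variance diagonal; avoids building the full matrix
)
dist.var_diag  # evaluates only the third thunk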
@@ -92,6 +107,13 @@ def _resolve_var(self):
# Ensure that the variance is a structured matrix for efficient operations.
self._var = convert(self._var, AbstractMatrix)

def _resolve_var_diag(self):
if self._var_diag is None:
if self._construct_var_diag is not None:
self._var_diag = self._construct_var_diag()
else:
self._var_diag = B.diag(self.var)

def __str__(self):
return (
f"<Normal:\n"
@@ -140,9 +162,15 @@ def var(self):
self._resolve_var()
return self._var

@property
def var_diag(self):
"""Diagonal of the variance."""
self._resolve_var_diag()
return self._var_diag

@property
def dtype(self):
"""Data type."""
"""Data type of the variance."""
return B.dtype(self.var)

@property
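If no diagonal constructor is supplied, `_resolve_var_diag` falls back to `B.diag(self.var)`, so `var_diag` is always available. A short sketch of that fallback:

from lab import B
from stheno import Normal

dist = Normal(lambda: 2 * B.eye(3))
dist.var_diag  # forces the variance, then takes its diagonal -> [2., 2., 2.]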
@@ -161,11 +189,13 @@ def marginals(self):
Returns:
tuple: A tuple containing the marginal means and marginal variances.
"""
mean, var_diag = self.mean, self.var_diag
# It can happen that the variances are slightly negative due to numerical noise.
# Prevent NaNs from the following square root by taking the maximum with zero.
# Also strip away any matrix structure.
return (
B.squeeze(B.dense(self.mean)),
B.maximum(B.diag(self.var), B.cast(self.dtype, 0)),
B.squeeze(B.dense(mean), axis=-1),
B.maximum(B.dense(var_diag), B.zero(var_diag)),
)
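A hedged illustration of the clamping: marginal variances that dip slightly below zero due to numerical noise are mapped to exactly zero before any downstream square root.

from lab import B

var_diag = B.concat(B.ones(3), -1e-12 * B.ones(1))
B.maximum(var_diag, B.zero(var_diag))  # -> [1., 1., 1., 0.]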

def marginal_credible_bounds(self):
@@ -175,10 +205,8 @@
tuple: A tuple containing the marginal means and marginal lower and
upper 95% central credible interval bounds.
"""
warnings.simplefilter(category=BreakingChangeWarning, action="ignore")
mean, variances = self.marginals()
warnings.simplefilter(category=BreakingChangeWarning, action="default")
error = 1.96 * B.sqrt(variances)
mean, var = self.marginals()
error = 1.96 * B.sqrt(var)
return mean, mean - error, mean + error

def logpdf(self, x):
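`marginal_credible_bounds` now simply reuses `marginals`, so the bounds mean ± 1.96 * sqrt(var) inherit the same diagonal-only computation. A self-contained usage sketch (same hypothetical setup as the earlier examples):

from lab import B
from mlkernels import EQ
from stheno import GP

p = GP(EQ())
x = B.linspace(0, 5, 100)
fdd = p(x, 0.1)

mean, lower, upper = fdd.marginal_credible_bounds()
m, v = fdd.marginals()
assert B.all(B.abs(upper - (m + 1.96 * B.sqrt(v))) < 1e-10)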
15 changes: 15 additions & 0 deletions tests/model/test_fdd.py
@@ -105,3 +105,18 @@ def test_fdd_take():
# Test that only masks are supported, for now.
with pytest.raises(AssertionError):
B.take(fdd, np.array([1, 2]))


def test_fdd_diag():
p = GP(EQ())

# Sample observations.
x = B.linspace(0, 5, 5)
y = p(x, 0.1).sample()

# Compute posterior.
p = p | (p(x, 0.1), y)

# Check that the diagonal is computed correctly.
fdd = p(B.linspace(0, 5, 10), 0.2)
approx(fdd.var_diag, B.diag(fdd.var))
15 changes: 14 additions & 1 deletion tests/model/test_gp.py
@@ -1,5 +1,6 @@
import numpy as np
import pytest
from time import time
from lab import B
from mlkernels import (
Linear,
@@ -179,7 +180,6 @@ def test_marginals():

# Check that `marginals` outputs the right thing.
mean, var = p(x).marginals()
var = B.diag(p.kernel(x))
approx(mean, p.mean(x)[:, 0])
approx(var, B.diag(p.kernel(x)))

@@ -196,3 +196,16 @@ def test_marginals():
mean, var = post(p)(x + 100).marginals()
approx(mean, p.mean(x + 100)[:, 0])
approx(var, B.diag(p.kernel(x + 100)))


def test_marginal_credible_bounds_efficiency():
p = GP(EQ())
x = B.linspace(0, 5, 5)
y = p(x, 0.1).sample()
p = p | (p(x, 0.1), y)

# Check that the computation at 10_000 points takes at most one second.
x = B.linspace(0, 5, 10_000)
start = time()
p(x, 0.2).marginal_credible_bounds()
assert time() - start < 1
20 changes: 20 additions & 0 deletions tests/test_random.py
@@ -66,6 +66,8 @@ def test_normal_mean_is_zero():

def test_normal_lazy_zero_mean():
dist = Normal(lambda: B.eye(3))
assert dist._mean is None
assert dist._var is None

assert dist.mean_is_zero
assert dist._mean is 0
@@ -81,6 +83,8 @@ def test_normal_lazy_zero_mean():

def test_normal_lazy_nonzero_mean():
dist = Normal(lambda: B.ones(3, 1), lambda: B.eye(3))
assert dist._mean is None
assert dist._var is None

assert not dist.mean_is_zero
approx(dist._mean, B.ones(3, 1))
@@ -92,6 +96,22 @@ def test_normal_lazy_nonzero_mean():
approx(dist.var, B.eye(3))


def test_normal_lazy_var_diag():
dist = Normal(lambda: B.eye(3))
assert dist._var is None
assert dist._var_diag is None

approx(dist.var_diag, B.ones(3))
approx(dist._var, B.eye(3))

dist = Normal(lambda: B.eye(3), var_diag=lambda: 9)
assert dist._var is None
assert dist._var_diag is None

assert dist.var_diag == 9
assert dist._var is None


def test_normal_m2(normal1):
approx(normal1.m2, normal1.var + normal1.mean @ normal1.mean.T)

2 changes: 2 additions & 0 deletions todo.tasks
@@ -6,6 +6,8 @@ TODO:
Bugs:
☐ MOK: check that it contains the fdd
☐ Noisy mixture posterior is bugged?
☐ Do not freely construct variance in `Normal.dtype` @high
Throw exception instead?

Misc:

