Clarify behavior on standard deviations with <1 degree of freedom and silence some unit test warnings (#2357)

Summary:
## Motivation

Unit tests were producing many warnings about taking standard deviations across fewer than 2 observations, and it was not clear whether these warnings were legitimate in context.
* When checking whether input data is standardized, no longer check the standard deviation if there is only one observation.
* For the `Standardize` outcome transform, explicitly set the standard deviation to 1 when there is only one observation. This matches the legacy behavior, but previously that was obscured because the standard deviation would become NaN before being corrected to 1.
* Raise an error when attempting to standardize 0 observations. This never worked; now the failure is explicit (see the sketch below).
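
A rough usage sketch of the resulting behavior, not part of this change; the import paths and call signatures are taken from the files touched in this diff.

```python
import torch

from botorch.models.transforms.outcome import Standardize
from botorch.models.utils.assorted import check_standardization

# One observation: the standard deviation is explicitly set to 1, so the
# transform only subtracts the mean (matching legacy behavior, without the
# intermediate NaN).
tf = Standardize(m=1)
Y_tf, _ = tf(torch.tensor([[2.5]]), None)
print(Y_tf)  # tensor([[0.]])

# One observation: the standardization check skips the standard deviation
# and only complains about a nonzero mean.
check_standardization(Y=torch.zeros(3, 1, 2), raise_on_fail=True)  # passes

# Zero observations: standardizing now raises a clear error instead of
# silently not working.
try:
    Standardize(m=1)(torch.zeros(0, 1), None)
except ValueError as e:
    print(e)  # "Can't standardize with no observations. ..."
```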

Pull Request resolved: #2357

Test Plan:
Added unit tests.

## Related PRs

Reviewed By: Balandat

Differential Revision: D57931412

Pulled By: esantorella

fbshipit-source-id: 36a9c81a950a0b92749673fdd22aec62b45aaae9
esantorella authored and facebook-github-bot committed Jun 11, 2024
1 parent 1e73b30 commit f79d608
Showing 5 changed files with 77 additions and 31 deletions.
10 changes: 9 additions & 1 deletion botorch/models/transforms/outcome.py
@@ -286,7 +286,15 @@ def forward(
f"Wrong output dimension. Y.size(-1) is {Y.size(-1)}; expected "
f"{self._m}."
)
stdvs = Y.std(dim=-2, keepdim=True)
if Y.shape[-2] < 1:
raise ValueError(f"Can't standardize with no observations. {Y.shape=}.")

elif Y.shape[-2] == 1:
stdvs = torch.ones(
(*Y.shape[:-2], 1, Y.shape[-1]), dtype=Y.dtype, device=Y.device
)
else:
stdvs = Y.std(dim=-2, keepdim=True)
stdvs = stdvs.where(stdvs >= self._min_stdv, torch.full_like(stdvs, 1.0))
means = Y.mean(dim=-2, keepdim=True)
if self._outputs is not None:
33 changes: 23 additions & 10 deletions botorch/models/utils/assorted.py
@@ -171,7 +171,7 @@ def check_min_max_scaling(
)
if raise_on_fail:
raise InputDataError(msg)
warnings.warn(msg, InputDataWarning)
warnings.warn(msg, InputDataWarning, stacklevel=2)


def check_standardization(
@@ -191,15 +191,28 @@
raise_on_fail: If True, raise an exception instead of a warning.
"""
with torch.no_grad():
Ymean, Ystd = torch.mean(Y, dim=-2), torch.std(Y, dim=-2)
if torch.abs(Ymean).max() > atol_mean or torch.abs(Ystd - 1).max() > atol_std:
msg = (
f"Input data is not standardized (mean = {Ymean}, std = {Ystd}). "
"Please consider scaling the input to zero mean and unit variance."
)
if raise_on_fail:
raise InputDataError(msg)
warnings.warn(msg, InputDataWarning)
Ymean = torch.mean(Y, dim=-2)
mean_not_zero = torch.abs(Ymean).max() > atol_mean
if Y.shape[-2] <= 1:
if mean_not_zero:
msg = (
f"Data is not standardized (mean = {Ymean}). "
"Please consider scaling the input to zero mean and unit variance."
)
if raise_on_fail:
raise InputDataError(msg)
warnings.warn(msg, InputDataWarning, stacklevel=2)
else:
Ystd = torch.std(Y, dim=-2)
std_not_one = torch.abs(Ystd - 1).max() > atol_std
if mean_not_zero or std_not_one:
msg = (
f"Data is not standardized (std = {Ystd}, mean = {Ymean}). "
"Please consider scaling the input to zero mean and unit variance."
)
if raise_on_fail:
raise InputDataError(msg)
warnings.warn(msg, InputDataWarning, stacklevel=2)


def validate_input_scaling(
2 changes: 1 addition & 1 deletion botorch/utils/testing.py
@@ -60,7 +60,7 @@ def setUp(self, suppress_input_warnings: bool = True) -> None:
)
warnings.filterwarnings(
"ignore",
message="Input data is not standardized.",
message="Data is not standardized.",
category=InputDataWarning,
)
warnings.filterwarnings(
32 changes: 21 additions & 11 deletions test/models/transforms/test_outcome.py
@@ -115,7 +115,7 @@ def test_is_linear(self) -> None:
)
self.assertEqual(posterior_is_gpt, transform._is_linear)

def test_standardize(self):
def test_standardize_raises_when_no_observations(self) -> None:
tf = Standardize(m=1)
with self.assertRaisesRegex(
ValueError, "Can't standardize with no observations."
):
tf(torch.zeros(0, 1, device=self.device), None)

def test_standardize(self) -> None:
# test error on incompatible dim
tf = Standardize(m=1)
with self.assertRaisesRegex(
@@ -134,9 +141,10 @@ def test_standardize(self):
ms = (1, 2)
batch_shapes = (torch.Size(), torch.Size([2]))
dtypes = (torch.float, torch.double)
ns = [1, 3]

# test transform, untransform, untransform_posterior
for m, batch_shape, dtype in itertools.product(ms, batch_shapes, dtypes):
for m, batch_shape, dtype, n in itertools.product(ms, batch_shapes, dtypes, ns):
# test init
tf = Standardize(m=m, batch_shape=batch_shape)
self.assertTrue(tf.training)
@@ -148,7 +156,7 @@ def test_standardize(self):
# no observation noise
with torch.random.fork_rng():
torch.manual_seed(0)
Y = torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype)
Y = torch.rand(*batch_shape, n, m, device=self.device, dtype=dtype)
Y_tf, Yvar_tf = tf(Y, None)
self.assertTrue(tf.training)
self.assertTrue(torch.all(Y_tf.mean(dim=-2).abs() < 1e-4))
@@ -171,14 +179,16 @@ def test_standardize(self):
tf = Standardize(m=m, batch_shape=batch_shape)
with torch.random.fork_rng():
torch.manual_seed(0)
Y = torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype)
Y = torch.rand(*batch_shape, n, m, device=self.device, dtype=dtype)
Yvar = 1e-8 + torch.rand(
*batch_shape, 3, m, device=self.device, dtype=dtype
*batch_shape, n, m, device=self.device, dtype=dtype
)
Y_tf, Yvar_tf = tf(Y, Yvar)
self.assertTrue(tf.training)
self.assertTrue(torch.all(Y_tf.mean(dim=-2).abs() < 1e-4))
Yvar_tf_expected = Yvar / Y.std(dim=-2, keepdim=True) ** 2
Yvar_tf_expected = (
Yvar if n == 1 else Yvar / Y.std(dim=-2, keepdim=True) ** 2
)
self.assertAllClose(Yvar_tf, Yvar_tf_expected)
tf.eval()
self.assertFalse(tf.training)
@@ -190,7 +200,7 @@ def test_standardize(self):
for interleaved, lazy in itertools.product((True, False), (True, False)):
if m == 1 and interleaved: # interleave has no meaning for m=1
continue
shape = batch_shape + torch.Size([3, m])
shape = batch_shape + torch.Size([n, m])
posterior = _get_test_posterior(
shape,
device=self.device,
@@ -216,12 +226,12 @@ def test_standardize(self):
# Untransform BlockDiagLinearOperator.
if m > 1:
base_lcv = DiagLinearOperator(
torch.rand(*batch_shape, m, 3, device=self.device, dtype=dtype)
torch.rand(*batch_shape, m, n, device=self.device, dtype=dtype)
)
lcv = BlockDiagLinearOperator(base_lcv)
mvn = MultitaskMultivariateNormal(
mean=torch.rand(
*batch_shape, 3, m, device=self.device, dtype=dtype
*batch_shape, n, m, device=self.device, dtype=dtype
),
covariance_matrix=lcv,
interleaved=False,
@@ -240,7 +250,7 @@ def test_standardize(self):
samples2 = p_utf.rsample(sample_shape=torch.Size([4, 2]))
self.assertEqual(
samples2.shape,
torch.Size([4, 2]) + batch_shape + torch.Size([3, m]),
torch.Size([4, 2]) + batch_shape + torch.Size([n, m]),
)

# untransform_posterior for non-GPyTorch posterior
@@ -252,7 +262,7 @@ def test_standardize(self):
)
p_utf2 = tf.untransform_posterior(posterior2)
self.assertEqual(p_utf2.device.type, self.device.type)
self.assertTrue(p_utf2.dtype == dtype)
self.assertEqual(p_utf2.dtype, dtype)
mean_expected = tf.means + tf.stdvs * posterior.mean
variance_expected = tf.stdvs**2 * posterior.variance
self.assertAllClose(p_utf2.mean, mean_expected)
31 changes: 23 additions & 8 deletions test/models/utils/test_assorted.py
@@ -141,26 +141,41 @@ def test_check_min_max_scaling(self):
def test_check_standardization(self):
# Ensure that it is not filtered out.
warnings.filterwarnings("always", category=InputDataWarning)
torch.manual_seed(0)
Y = torch.randn(3, 4, 2)
# check standardized input
Yst = (Y - Y.mean(dim=-2, keepdim=True)) / Y.std(dim=-2, keepdim=True)
with warnings.catch_warnings(record=True) as ws:
check_standardization(Y=Yst)
self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws))
check_standardization(Y=Yst, raise_on_fail=True)
# check nonzero mean

# check standardized input with one observation
y = torch.zeros((3, 1, 2))
with warnings.catch_warnings(record=True) as ws:
check_standardization(Y=y)
self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws))
check_standardization(Y=y, raise_on_fail=True)

# check nonzero mean for case where >= 2 observations per batch
msg_more_than_1_obs = r"Data is not standardized \(std ="
with self.assertWarnsRegex(InputDataWarning, msg_more_than_1_obs):
check_standardization(Y=Yst + 1)
self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws))
self.assertTrue(any("not standardized" in str(w.message) for w in ws))
with self.assertRaises(InputDataError):
with self.assertRaisesRegex(InputDataError, msg_more_than_1_obs):
check_standardization(Y=Yst + 1, raise_on_fail=True)

# check nonzero mean for case where < 2 observations per batch
msg_one_obs = r"Data is not standardized \(mean ="
y = torch.ones((3, 1, 2), dtype=torch.float32)
with self.assertWarnsRegex(InputDataWarning, msg_one_obs):
check_standardization(Y=y)
with self.assertRaisesRegex(InputDataError, msg_one_obs):
check_standardization(Y=y, raise_on_fail=True)

# check non-unit variance
with warnings.catch_warnings(record=True) as ws:
with self.assertWarnsRegex(InputDataWarning, msg_more_than_1_obs):
check_standardization(Y=Yst * 2)
self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws))
self.assertTrue(any("not standardized" in str(w.message) for w in ws))
with self.assertRaises(InputDataError):
with self.assertRaisesRegex(InputDataError, msg_more_than_1_obs):
check_standardization(Y=Yst * 2, raise_on_fail=True)

def test_validate_input_scaling(self):
