From f79d6080cc2be93381727c7cd7a0bd63f1bc96b5 Mon Sep 17 00:00:00 2001
From: Elizabeth Santorella
Date: Tue, 11 Jun 2024 12:36:45 -0700
Subject: [PATCH] Clarify behavior on standard deviations with <1 degree of
 freedom and silence some unit test warnings (#2357)

Summary:
## Motivation

Unit tests were producing a lot of warnings about taking standard deviations
across fewer than 2 observations, and it was not clear to me whether these
warnings were legitimate in context.

* For checking the standardization of input data, no longer check the
  standard deviation when there is only one observation.
* For the `Standardize` outcome transform, explicitly set standard deviations
  to 1 when there is only one observation. This matches the legacy behavior,
  but that was previously obscured because the standard deviation would
  become NaN before being corrected to 1.
* Raise an error on attempting to standardize 0 observations. This case never
  worked; it now fails with a clear error message.

(A short usage sketch illustrating these behaviors follows the diff.)

Pull Request resolved: https://github.com/pytorch/botorch/pull/2357

Test Plan: Added unit tests

## Related PRs

Reviewed By: Balandat

Differential Revision: D57931412

Pulled By: esantorella

fbshipit-source-id: 36a9c81a950a0b92749673fdd22aec62b45aaae9
---
 botorch/models/transforms/outcome.py   | 10 +++++++-
 botorch/models/utils/assorted.py       | 33 ++++++++++++++++++--------
 botorch/utils/testing.py               |  2 +-
 test/models/transforms/test_outcome.py | 32 ++++++++++++++++---------
 test/models/utils/test_assorted.py     | 31 +++++++++++++++++-------
 5 files changed, 77 insertions(+), 31 deletions(-)

diff --git a/botorch/models/transforms/outcome.py b/botorch/models/transforms/outcome.py
index 9e3ec07853..0bcd6a6a78 100644
--- a/botorch/models/transforms/outcome.py
+++ b/botorch/models/transforms/outcome.py
@@ -286,7 +286,15 @@ def forward(
                 f"Wrong output dimension. Y.size(-1) is {Y.size(-1)}; expected "
                 f"{self._m}."
             )
-        stdvs = Y.std(dim=-2, keepdim=True)
+        if Y.shape[-2] < 1:
+            raise ValueError(f"Can't standardize with no observations. {Y.shape=}.")
+
+        elif Y.shape[-2] == 1:
+            stdvs = torch.ones(
+                (*Y.shape[:-2], 1, Y.shape[-1]), dtype=Y.dtype, device=Y.device
+            )
+        else:
+            stdvs = Y.std(dim=-2, keepdim=True)
         stdvs = stdvs.where(stdvs >= self._min_stdv, torch.full_like(stdvs, 1.0))
         means = Y.mean(dim=-2, keepdim=True)
         if self._outputs is not None:
diff --git a/botorch/models/utils/assorted.py b/botorch/models/utils/assorted.py
index bb7acb2ffd..0272d8057f 100644
--- a/botorch/models/utils/assorted.py
+++ b/botorch/models/utils/assorted.py
@@ -171,7 +171,7 @@ def check_min_max_scaling(
             )
             if raise_on_fail:
                 raise InputDataError(msg)
-            warnings.warn(msg, InputDataWarning)
+            warnings.warn(msg, InputDataWarning, stacklevel=2)
 
 
 def check_standardization(
@@ -191,15 +191,28 @@ def check_standardization(
         raise_on_fail: If True, raise an exception instead of a warning.
     """
     with torch.no_grad():
-        Ymean, Ystd = torch.mean(Y, dim=-2), torch.std(Y, dim=-2)
-        if torch.abs(Ymean).max() > atol_mean or torch.abs(Ystd - 1).max() > atol_std:
-            msg = (
-                f"Input data is not standardized (mean = {Ymean}, std = {Ystd}). "
-                "Please consider scaling the input to zero mean and unit variance."
-            )
-            if raise_on_fail:
-                raise InputDataError(msg)
-            warnings.warn(msg, InputDataWarning)
+        Ymean = torch.mean(Y, dim=-2)
+        mean_not_zero = torch.abs(Ymean).max() > atol_mean
+        if Y.shape[-2] <= 1:
+            if mean_not_zero:
+                msg = (
+                    f"Data is not standardized (mean = {Ymean}). "
+                    "Please consider scaling the input to zero mean and unit variance."
+                )
+                if raise_on_fail:
+                    raise InputDataError(msg)
+                warnings.warn(msg, InputDataWarning, stacklevel=2)
+        else:
+            Ystd = torch.std(Y, dim=-2)
+            std_not_one = torch.abs(Ystd - 1).max() > atol_std
+            if mean_not_zero or std_not_one:
+                msg = (
+                    f"Data is not standardized (std = {Ystd}, mean = {Ymean}). "
+                    "Please consider scaling the input to zero mean and unit variance."
+                )
+                if raise_on_fail:
+                    raise InputDataError(msg)
+                warnings.warn(msg, InputDataWarning, stacklevel=2)
 
 
 def validate_input_scaling(
diff --git a/botorch/utils/testing.py b/botorch/utils/testing.py
index 81bb7d15a4..4d159f7db1 100644
--- a/botorch/utils/testing.py
+++ b/botorch/utils/testing.py
@@ -60,7 +60,7 @@ def setUp(self, suppress_input_warnings: bool = True) -> None:
         )
         warnings.filterwarnings(
             "ignore",
-            message="Input data is not standardized.",
+            message="Data is not standardized.",
             category=InputDataWarning,
         )
         warnings.filterwarnings(
diff --git a/test/models/transforms/test_outcome.py b/test/models/transforms/test_outcome.py
index d4acc388eb..be49b46cb4 100644
--- a/test/models/transforms/test_outcome.py
+++ b/test/models/transforms/test_outcome.py
@@ -115,7 +115,14 @@ def test_is_linear(self) -> None:
             )
             self.assertEqual(posterior_is_gpt, transform._is_linear)
 
-    def test_standardize(self):
+    def test_standardize_raises_when_no_observations(self) -> None:
+        tf = Standardize(m=1)
+        with self.assertRaisesRegex(
+            ValueError, "Can't standardize with no observations."
+        ):
+            tf(torch.zeros(0, 1, device=self.device), None)
+
+    def test_standardize(self) -> None:
         # test error on incompatible dim
         tf = Standardize(m=1)
         with self.assertRaisesRegex(
@@ -134,9 +141,10 @@ def test_standardize(self):
         ms = (1, 2)
         batch_shapes = (torch.Size(), torch.Size([2]))
         dtypes = (torch.float, torch.double)
+        ns = [1, 3]
 
         # test transform, untransform, untransform_posterior
-        for m, batch_shape, dtype in itertools.product(ms, batch_shapes, dtypes):
+        for m, batch_shape, dtype, n in itertools.product(ms, batch_shapes, dtypes, ns):
             # test init
             tf = Standardize(m=m, batch_shape=batch_shape)
             self.assertTrue(tf.training)
@@ -148,7 +156,7 @@ def test_standardize(self):
             # no observation noise
             with torch.random.fork_rng():
                 torch.manual_seed(0)
-                Y = torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype)
+                Y = torch.rand(*batch_shape, n, m, device=self.device, dtype=dtype)
             Y_tf, Yvar_tf = tf(Y, None)
             self.assertTrue(tf.training)
             self.assertTrue(torch.all(Y_tf.mean(dim=-2).abs() < 1e-4))
@@ -171,14 +179,16 @@ def test_standardize(self):
             tf = Standardize(m=m, batch_shape=batch_shape)
             with torch.random.fork_rng():
                 torch.manual_seed(0)
-                Y = torch.rand(*batch_shape, 3, m, device=self.device, dtype=dtype)
+                Y = torch.rand(*batch_shape, n, m, device=self.device, dtype=dtype)
                 Yvar = 1e-8 + torch.rand(
-                    *batch_shape, 3, m, device=self.device, dtype=dtype
+                    *batch_shape, n, m, device=self.device, dtype=dtype
                 )
             Y_tf, Yvar_tf = tf(Y, Yvar)
             self.assertTrue(tf.training)
             self.assertTrue(torch.all(Y_tf.mean(dim=-2).abs() < 1e-4))
-            Yvar_tf_expected = Yvar / Y.std(dim=-2, keepdim=True) ** 2
+            Yvar_tf_expected = (
+                Yvar if n == 1 else Yvar / Y.std(dim=-2, keepdim=True) ** 2
+            )
             self.assertAllClose(Yvar_tf, Yvar_tf_expected)
             tf.eval()
             self.assertFalse(tf.training)
@@ -190,7 +200,7 @@ def test_standardize(self):
             for interleaved, lazy in itertools.product((True, False), (True, False)):
                 if m == 1 and interleaved:  # interleave has no meaning for m=1
                     continue
-                shape = batch_shape + torch.Size([3, m])
+                shape = batch_shape + torch.Size([n, m])
                 posterior = _get_test_posterior(
                     shape,
                     device=self.device,
                     dtype=dtype,
                     interleaved=interleaved,
                     lazy=lazy,
                 )
@@ -216,12 +226,12 @@ def test_standardize(self):
                 # Untransform BlockDiagLinearOperator.
                 if m > 1:
                     base_lcv = DiagLinearOperator(
-                        torch.rand(*batch_shape, m, 3, device=self.device, dtype=dtype)
+                        torch.rand(*batch_shape, m, n, device=self.device, dtype=dtype)
                     )
                     lcv = BlockDiagLinearOperator(base_lcv)
                     mvn = MultitaskMultivariateNormal(
                         mean=torch.rand(
-                            *batch_shape, 3, m, device=self.device, dtype=dtype
+                            *batch_shape, n, m, device=self.device, dtype=dtype
                         ),
                         covariance_matrix=lcv,
                         interleaved=False,
@@ -240,7 +250,7 @@ def test_standardize(self):
                 samples2 = p_utf.rsample(sample_shape=torch.Size([4, 2]))
                 self.assertEqual(
                     samples2.shape,
-                    torch.Size([4, 2]) + batch_shape + torch.Size([3, m]),
+                    torch.Size([4, 2]) + batch_shape + torch.Size([n, m]),
                 )
 
             # untransform_posterior for non-GPyTorch posterior
@@ -252,7 +262,7 @@ def test_standardize(self):
             )
             p_utf2 = tf.untransform_posterior(posterior2)
             self.assertEqual(p_utf2.device.type, self.device.type)
-            self.assertTrue(p_utf2.dtype == dtype)
+            self.assertEqual(p_utf2.dtype, dtype)
             mean_expected = tf.means + tf.stdvs * posterior.mean
             variance_expected = tf.stdvs**2 * posterior.variance
             self.assertAllClose(p_utf2.mean, mean_expected)
diff --git a/test/models/utils/test_assorted.py b/test/models/utils/test_assorted.py
index fd7bd7d03c..459893363e 100644
--- a/test/models/utils/test_assorted.py
+++ b/test/models/utils/test_assorted.py
@@ -141,6 +141,7 @@ def test_check_min_max_scaling(self):
     def test_check_standardization(self):
         # Ensure that it is not filtered out.
         warnings.filterwarnings("always", category=InputDataWarning)
+        torch.manual_seed(0)
         Y = torch.randn(3, 4, 2)
         # check standardized input
         Yst = (Y - Y.mean(dim=-2, keepdim=True)) / Y.std(dim=-2, keepdim=True)
@@ -148,19 +149,33 @@ def test_check_standardization(self):
             check_standardization(Y=Yst)
         self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws))
         check_standardization(Y=Yst, raise_on_fail=True)
-        # check nonzero mean
+
+        # check standardized input with one observation
+        y = torch.zeros((3, 1, 2))
         with warnings.catch_warnings(record=True) as ws:
+            check_standardization(Y=y)
+        self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws))
+        check_standardization(Y=y, raise_on_fail=True)
+
+        # check nonzero mean for case where >= 2 observations per batch
+        msg_more_than_1_obs = r"Data is not standardized \(std ="
+        with self.assertWarnsRegex(InputDataWarning, msg_more_than_1_obs):
             check_standardization(Y=Yst + 1)
-        self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws))
-        self.assertTrue(any("not standardized" in str(w.message) for w in ws))
-        with self.assertRaises(InputDataError):
+        with self.assertRaisesRegex(InputDataError, msg_more_than_1_obs):
             check_standardization(Y=Yst + 1, raise_on_fail=True)
+
+        # check nonzero mean for case where < 2 observations per batch
+        msg_one_obs = r"Data is not standardized \(mean ="
+        y = torch.ones((3, 1, 2), dtype=torch.float32)
+        with self.assertWarnsRegex(InputDataWarning, msg_one_obs):
+            check_standardization(Y=y)
+        with self.assertRaisesRegex(InputDataError, msg_one_obs):
+            check_standardization(Y=y, raise_on_fail=True)
+
         # check non-unit variance
-        with warnings.catch_warnings(record=True) as ws:
+        with self.assertWarnsRegex(InputDataWarning, msg_more_than_1_obs):
             check_standardization(Y=Yst * 2)
-        self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws))
-        self.assertTrue(any("not standardized" in str(w.message) for w in ws))
-        with self.assertRaises(InputDataError):
+        with self.assertRaisesRegex(InputDataError, msg_more_than_1_obs):
             check_standardization(Y=Yst * 2, raise_on_fail=True)
 
     def test_validate_input_scaling(self):
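
The sketch below is illustrative only and not part of the patch. It exercises the behavior changes described in the Motivation section, using the public `Standardize` and `check_standardization` APIs as modified in the diffs, assuming a BoTorch build that includes this change:

```python
import torch

from botorch.models.transforms.outcome import Standardize
from botorch.models.utils.assorted import check_standardization

# One observation: the standard deviation is now explicitly set to 1 rather
# than becoming NaN, so the single point is simply mean-centered to zero.
Y_tf, _ = Standardize(m=1)(torch.rand(1, 1), None)
assert torch.allclose(Y_tf, torch.zeros(1, 1))

# Zero observations: standardizing now fails with an explicit error.
try:
    Standardize(m=1)(torch.zeros(0, 1), None)
except ValueError as e:
    print(e)  # Can't standardize with no observations. Y.shape=torch.Size([0, 1]).

# A single zero-mean observation per batch no longer triggers a warning about
# a standard deviation computed from fewer than 2 observations.
check_standardization(Y=torch.zeros(3, 1, 2), raise_on_fail=True)  # passes silently
```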