From 5985f288c63dc5c1d3969c08a38497cf86d8fc09 Mon Sep 17 00:00:00 2001
From: Zach Carmichael
Date: Wed, 11 Dec 2024 08:56:14 -0800
Subject: [PATCH] Fix flaky test_softmax_classification_batch_multi_target
 test case by increasing precision

Summary:
The test case test_softmax_classification_batch_multi_target is flaky and can
fail due to floating point error. This diff changes the test case to use
doubles instead of single floats.

Differential Revision: D67071680
---
 captum/attr/_core/deep_lift.py             |  2 --
 tests/attr/test_deeplift_classification.py | 10 +++++++---
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/captum/attr/_core/deep_lift.py b/captum/attr/_core/deep_lift.py
index 6e6beb7b5..d7997195e 100644
--- a/captum/attr/_core/deep_lift.py
+++ b/captum/attr/_core/deep_lift.py
@@ -523,8 +523,6 @@ def pre_hook(
                         additional_args, 2, ExpansionTypes.repeat
                     ),
                 )
-                # pyre-fixme[60]: Concatenation not yet support for multiple
-                #  variadic tuples: `*baseline_input_tsr, *expanded_additional_args`.
                 return (*baseline_input_tsr, *expanded_additional_args)
             return baseline_input_tsr
 
diff --git a/tests/attr/test_deeplift_classification.py b/tests/attr/test_deeplift_classification.py
index 1c5e49387..85a9db00d 100644
--- a/tests/attr/test_deeplift_classification.py
+++ b/tests/attr/test_deeplift_classification.py
@@ -65,9 +65,13 @@ def test_softmax_classification_batch_zero_baseline(self) -> None:
 
     def test_softmax_classification_batch_multi_target(self) -> None:
         num_in = 40
-        inputs = torch.arange(0.0, num_in * 3.0, requires_grad=True).reshape(3, num_in)
-        baselines = torch.arange(1.0, num_in + 1).reshape(1, num_in)
-        model = SoftmaxDeepLiftModel(num_in, 20, 10)
+        inputs = (
+            torch.arange(0.0, num_in * 3.0, requires_grad=True)
+            .reshape(3, num_in)
+            .double()
+        )
+        baselines = torch.arange(1.0, num_in + 1).reshape(1, num_in).double()
+        model = SoftmaxDeepLiftModel(num_in, 20, 10).double()
         dl = DeepLift(model)
         self.softmax_classification(
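
For context, a minimal sketch of the failure mode the Summary describes: the
same reduction computed in two orders can disagree by more than a tight
tolerance in float32 yet agree to near machine precision in float64. The
sum_residual function below is purely illustrative; it is not part of the
Captum test suite.

    import torch

    def sum_residual(dtype: torch.dtype) -> float:
        torch.manual_seed(0)
        # Stand-ins for the per-feature attribution values the test compares.
        values = torch.rand(10_000, dtype=dtype) * 1e3
        # Summing in a different order accumulates different rounding error.
        forward = values.sum()
        backward = values.flip(0).sum()
        return (forward - backward).abs().item()

    print(sum_residual(torch.float32))  # can exceed a 1e-4 tolerance
    print(sum_residual(torch.float64))  # many orders of magnitude smaller

Note that the patch calls .double() on the model as well as on the input and
baseline tensors: PyTorch layers raise a dtype mismatch if parameters and
inputs disagree, so all three must be converted together.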