
Commit 5985f28

Fix flaky test_softmax_classification_batch_multi_target test case by increasing precision

Summary: The test case test_softmax_classification_batch_multi_target is flaky and can fail due to floating-point error. This diff changes the test to use double-precision floats (float64) instead of single-precision floats (float32).

Differential Revision: D67071680
craymichael authored and facebook-github-bot committed Dec 11, 2024
1 parent 92d82df commit 5985f28
Showing 2 changed files with 7 additions and 5 deletions.
captum/attr/_core/deep_lift.py (0 additions, 2 deletions)

@@ -523,8 +523,6 @@ def pre_hook(
                         additional_args, 2, ExpansionTypes.repeat
                     ),
                 )
-                # pyre-fixme[60]: Concatenation not yet support for multiple
-                #  variadic tuples: `*baseline_input_tsr, *expanded_additional_args`.
                 return (*baseline_input_tsr, *expanded_additional_args)
             return baseline_input_tsr
 
tests/attr/test_deeplift_classification.py (7 additions, 3 deletions)

@@ -65,9 +65,13 @@ def test_softmax_classification_batch_zero_baseline(self) -> None:
 
     def test_softmax_classification_batch_multi_target(self) -> None:
         num_in = 40
-        inputs = torch.arange(0.0, num_in * 3.0, requires_grad=True).reshape(3, num_in)
-        baselines = torch.arange(1.0, num_in + 1).reshape(1, num_in)
-        model = SoftmaxDeepLiftModel(num_in, 20, 10)
+        inputs = (
+            torch.arange(0.0, num_in * 3.0, requires_grad=True)
+            .reshape(3, num_in)
+            .double()
+        )
+        baselines = torch.arange(1.0, num_in + 1).reshape(1, num_in).double()
+        model = SoftmaxDeepLiftModel(num_in, 20, 10).double()
         dl = DeepLift(model)
 
         self.softmax_classification(
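For context, a minimal sketch of the precision pattern the test now follows. This is not the commit's code: TinyNet and the surrounding setup are hypothetical stand-ins for the test's SoftmaxDeepLiftModel. The key points are that float64 carries roughly 1e-16 relative rounding error versus float32's roughly 1e-7, and that the model must be cast along with the tensors, since PyTorch raises a dtype-mismatch error when float32 weights meet float64 inputs.

# Hypothetical sketch, not part of this commit.
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Stand-in for the test's SoftmaxDeepLiftModel."""

    def __init__(self, num_in: int) -> None:
        super().__init__()
        self.lin = nn.Linear(num_in, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.softmax(self.lin(x), dim=-1)


num_in = 40
# Cast the data to float64, mirroring the diff above.
inputs = (
    torch.arange(0.0, num_in * 3.0, requires_grad=True)
    .reshape(3, num_in)
    .double()
)
baselines = torch.arange(1.0, num_in + 1).reshape(1, num_in).double()

# The parameters must match: float32 weights applied to float64 inputs
# raise a RuntimeError, so the model is cast to double as well.
model = TinyNet(num_in).double()
print(model(inputs).dtype)  # torch.float64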
