From 543352c43198b9e8b111ca7801f5a8d136c95b66 Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Wed, 20 Nov 2024 07:24:15 -0800 Subject: [PATCH] Prepare for "Fix type-safety of `torch.nn.Module` instances": fbcode/p* Summary: See D52890934 Differential Revision: D66235323 --- captum/attr/_core/feature_ablation.py | 5 ++++- captum/attr/_core/occlusion.py | 2 ++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/captum/attr/_core/feature_ablation.py b/captum/attr/_core/feature_ablation.py index dad8f4756..e5e60bb46 100644 --- a/captum/attr/_core/feature_ablation.py +++ b/captum/attr/_core/feature_ablation.py @@ -808,7 +808,10 @@ def _construct_ablated_input( current_mask = current_mask.to(expanded_input.device) assert baseline is not None, "baseline must be provided" ablated_tensor = ( - expanded_input * (1 - current_mask).to(expanded_input.dtype) + expanded_input + * (1 - current_mask).to(expanded_input.dtype) + # pyre-fixme[58]: `*` is not supported for operand types `Union[None, float, + # Tensor]` and `Tensor`. ) + (baseline * current_mask.to(expanded_input.dtype)) return ablated_tensor, current_mask diff --git a/captum/attr/_core/occlusion.py b/captum/attr/_core/occlusion.py index fe5105c96..f6bfcbe8a 100644 --- a/captum/attr/_core/occlusion.py +++ b/captum/attr/_core/occlusion.py @@ -323,6 +323,8 @@ def _construct_ablated_input( torch.ones(1, dtype=torch.long, device=expanded_input.device) - input_mask ).to(expanded_input.dtype) + # pyre-fixme[58]: `*` is not supported for operand types `Union[None, float, + # Tensor]` and `Tensor`. ) + (baseline * input_mask.to(expanded_input.dtype)) return ablated_tensor, input_mask