Skip to content

Commit

Permalink
Fix deeplift mypy error (#1459)
Browse files Browse the repository at this point in the history
Summary:

Currently, Captum OSS tests are failing due to mypy failures (likely from new version) in DeepLift test cases. Adds fix for type failure caused by different signature between DeepLift and DeepLiftShap.

Differential Revision: D67538043
  • Loading branch information
Vivek Miglani authored and facebook-github-bot committed Dec 20, 2024
1 parent 600dcb3 commit b14b423
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tests/attr/test_deeplift_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# pyre-unsafe

from typing import Union
from typing import TypeVar, Union

import torch
from captum._utils.typing import TargetType
Expand All @@ -21,6 +21,8 @@
from torch import Tensor
from torch.nn import Module

DeepLiftAttrMethod = TypeVar("DeepLiftAttrMethod", DeepLift, DeepLiftShap)


class Test(BaseTest):
def test_sigmoid_classification(self) -> None:
Expand Down Expand Up @@ -155,7 +157,7 @@ def test_convnet_with_maxpool1d_large_baselines(self) -> None:
def softmax_classification(
self,
model: Module,
attr_method: Union[DeepLift, DeepLiftShap],
attr_method: DeepLiftAttrMethod,
input: Tensor,
baselines: Union[float, int, Tensor],
target: TargetType,
Expand Down

0 comments on commit b14b423

Please sign in to comment.