How to add support to DenseNet from torchvision #165
Hey @mstaczek, the truth is, DenseNet is quite tricky with LRP: the DenseLayers within the DenseBlocks end with a linear layer without an activation, and start with a BatchNorm. This means that, because of the residual connections, multiple BatchNorm layers are connected to multiple Linear layers, so they cannot be merged into the linear layer as is normally done.

Image from the paper: [figure omitted]
Text representation of the torchvision model: [collapsed details omitted]
**Current State**

Using DenseNet without canonizers does not work correctly. You can use the Epsilon rule in BatchNorm layers for slightly better results. Here's some code to produce heatmaps with densenet121:

```python
import os

import torch
from torchvision.models import densenet121
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, CenterCrop
from PIL import Image

from zennit.attribution import Gradient
from zennit.composites import EpsilonGammaBox
from zennit.types import BatchNorm
from zennit.image import imgify
from zennit.rules import Epsilon

# download an example image
fname = 'dornbusch-lighthouse.jpg'
if not os.path.exists(fname):
    torch.hub.download_url_to_file(
        'https://upload.wikimedia.org/wikipedia/commons/thumb/8/8b/2006_09_06_180_Leuchtturm.jpg/640px-2006_09_06_181_Leuchtturm.jpg',
        fname,
    )

# define the base image transform
transform_img = Compose([
    Resize(256),
    CenterCrop(224),
])
# define the normalization transform
transform_norm = Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
# define the full tensor transform
transform = Compose([
    transform_img,
    ToTensor(),
    transform_norm,
])

# load the image
image = Image.open(fname)
# transform the PIL image and insert a batch-dimension
data = transform(image)[None]

model = densenet121(weights='DEFAULT').eval()

# use the Epsilon rule for BatchNorm layers on top of EpsilonGammaBox
composite = EpsilonGammaBox(low=-3., high=3., layer_map=[(BatchNorm, Epsilon())])

input = data.clone().requires_grad_(True)
target = torch.eye(1000)[[437]]
with Gradient(model, composite) as attributor:
    output, relevance = attributor(input, target)

transform_img(image).save('original.png')
imgify(relevance[0].detach().sum(0), cmap='bwr', symmetric=True).save('densenet121.png')
```

**Implementing a DenseNet Canonizer**

Ultimately, a canonizer also needs to be implemented for DenseNet, due to its problematic BatchNorms. There are a few settings of BatchNorms that need different handling:
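For reference, the canonization step that DenseNet breaks is folding a BatchNorm into the single linear layer that consumes its output. A minimal plain-PyTorch sketch of that merge (not zennit code, and only valid when the BatchNorm has exactly one linear consumer):

```python
import torch

torch.manual_seed(0)

# a BatchNorm followed by exactly one Linear layer, in eval mode so the
# running statistics (not batch statistics) are used
bn = torch.nn.BatchNorm1d(4).eval()
linear = torch.nn.Linear(4, 3).eval()
with torch.no_grad():
    bn.running_mean.normal_()
    bn.running_var.uniform_(0.5, 1.5)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.normal_()

# in eval mode, bn(x) = scale * x + shift, elementwise
scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
shift = bn.bias - bn.running_mean * scale

# fold the affine map into the Linear layer:
# linear(bn(x)) = (W * scale) x + (W @ shift + b)
merged = torch.nn.Linear(4, 3).eval()
with torch.no_grad():
    merged.weight.copy_(linear.weight * scale)
    merged.bias.copy_(linear.bias + linear.weight @ shift)

# the merged layer computes the same function as the BatchNorm + Linear pair
x = torch.randn(8, 4)
with torch.no_grad():
    assert torch.allclose(linear(bn(x)), merged(x), atol=1e-5)
```

In DenseNet, each BatchNorm acts on a concatenation of earlier feature maps that is re-used by several later layers, so there is no single linear layer to fold it into; one conceivable direction for a canonizer is to give each consumer its own merged copy.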
There may be things that I overlooked, but LRP for DenseNet is, by the design of LRP, currently quite a challenge to get right. It needs careful thought in order to be done as implied by the definition of LRP. I will try to discuss this in our lab and see whether there's a better solution, but until then you can try the Epsilon rule for BatchNorm layers.
Wow, I did not expect it to be such a challenge! Thank you for the explanation and sample heatmaps; they really help to show that a custom canonizer is necessary for DenseNets. I will think about implementing it after reading more about DenseNet, its blocks, and LRP.
I wanted to use LRP with DenseNet121 from torchvision. So far, in zennit.torchvision I have found canonizers for ResNet and VGG, and I wonder whether I can use them (to get some results) or whether I need to write my own custom canonizer (because the network has new layers that were not covered by the existing canonizers).
Thanks for your help!