ResNet: Unstable Attributions #194
Hi @chr5tphr, is there any news on this bug? I noticed the following when setting […]:

```python
import torch
import torch.nn as nn
from torchvision.models import resnet34
from zennit.torchvision import ResNetCanonizer
from zennit.composites import EpsilonPlusFlat
import matplotlib.pyplot as plt

model = resnet34(weights=None)

# create a composite, specifying the canonizers
composite = EpsilonPlusFlat(canonizers=[ResNetCanonizer()], zero_params='bias')

target = torch.eye(1000)[[437]]
input_data = torch.rand(1, 3, 224, 224)
input_data.requires_grad = True

with composite.context(model) as modified_model:
    output = modified_model(input_data)
    attribution, = torch.autograd.grad(output, input_data, target)

relevance = attribution.cpu().sum(1).squeeze(0)
if torch.any(relevance < 0):
    print(relevance[relevance < 0].sum())
    print(relevance[relevance > 0].sum())
    print(relevance.sum())

plt.imshow(relevance.numpy())
plt.show()
```

If you don't have the time to deal with this issue, could you maybe point me in the right direction so I can try to fix it myself? Resolving this is critical for me to continue with my research.
Hey @MikiFER, I am not sure that what you are experiencing concerns this issue, as the relevance still sums to 1. Here's a little snippet to check the model relevance in detail. I have added an extra rule to switch off the residual branch to circumvent the LRP-instabilities discussed in #148. The canonizer is simply ineffective for vgg11, which does not need one (although, since ResNetCanonizer includes MergeBatchNorm, vgg11bn would also work):

```python
from itertools import islice

import torch
import torch.nn as nn
from torchvision.models import resnet34, vgg11
from zennit.torchvision import ResNetCanonizer
from zennit.composites import EpsilonPlusFlat
from zennit.core import Hook
from zennit.layer import Sum


class SumSingle(Hook):
    def __init__(self, dim=1):
        super().__init__()
        self.dim = dim

    def backward(self, module, grad_input, grad_output):
        elems = [torch.zeros_like(grad_output[0])] * (grad_input[0].shape[-1])
        elems[self.dim] = grad_output[0]
        return (torch.stack(elems, dim=-1),)


def store_hook(module, input, output):
    module.output = output
    output.retain_grad()


models = {'resnet34': resnet34, 'vgg11': vgg11}

for model_name, model_fn in models.items():
    torch.manual_seed(0xdeadbeef + 3)
    model = model_fn(weights=None)
    model.eval()

    # create a composite, specifying the canonizers
    composite = EpsilonPlusFlat(
        layer_map=[(Sum, SumSingle(1))],
        canonizers=[ResNetCanonizer()],
        zero_params='bias'
    )
    target = torch.eye(1000)[[437]]
    input_data = torch.rand(1, 3, 224, 224)
    input_data.requires_grad = True

    with composite.context(model) as modified_model:
        handles = [
            module.register_forward_hook(store_hook)
            for module in model.modules()
            if not list(islice(module.children(), 1))
        ]
        output = modified_model(input_data)
        attribution, = torch.autograd.grad(output, input_data, target)

    relevance = attribution.cpu().sum(1).squeeze(0)
    labels = [('input', 'input', attribution)] + [
        (name, type(module).__name__, module.output.grad)
        for name, module in model.named_modules()
        if hasattr(module, 'output')
    ]
    maxname, maxclsname = [max(len(obj[i]) for obj in labels) for i in (0, 1)]
    print(f'\nModel: {model_name}')
    for name, clsname, grad in labels:
        print(
            f'    {name:<{maxname}s} ({maxclsname and clsname:<{maxclsname}s}): '
            f'min: {grad.min():+.7f}, '
            f'max: {grad.max():+.7f}, '
            f'sum: {grad.sum():+.7f}'
        )
```

And this is the output I get on `0.5.1`.
If you are on […]. Maybe skipping the residual branch already fixes your issue?
Hi @chr5tphr, thank you for your response. I have since discovered that I had misunderstood the EpsilonPlusFlat composite, and decided that the pure Alpha1Beta0 rule is actually what I need, because I want to obtain a relevance map with only positive influence on the input.
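For reference, the Alpha1Beta0 (a.k.a. z+) rule mentioned above only propagates relevance along positive contributions, so the resulting input relevance is non-negative by construction. This is a minimal numpy sketch of the rule for a single linear layer, independent of zennit's implementation; the activations, weights, and output relevance are made-up values for illustration:

```python
import numpy as np

def alpha1_beta0_linear(a, W, R_out, eps=1e-9):
    # z+ rule: keep only positive contributions a_j * w_jk
    zplus = np.clip(a[:, None] * W, 0.0, None)
    denom = zplus.sum(axis=0) + eps          # per-output normalizer
    # redistribute each output's relevance proportionally to positive parts
    return (zplus * (R_out / denom)[None, :]).sum(axis=1)

a = np.array([1.0, 0.5, 2.0, 0.0])           # post-ReLU activations (made up)
W = np.array([[ 1.0, -1.0,  0.5],
              [-2.0,  1.0,  1.0],
              [ 0.5,  0.5, -1.0],
              [ 1.0,  1.0,  1.0]])
R_out = np.array([0.5, 0.3, 0.2])

R_in = alpha1_beta0_linear(a, W, R_out)
# R_in is non-negative and conserves the total relevance (sums to 1)
print(R_in)
```

Note that non-negativity only holds when the incoming relevance `R_out` is itself non-negative, which is the case when the rule is applied consistently from the output down.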
With the introduction of #185, ResNet18 attributions result in negative attribution sums in the input layer, leading to bad attributions. Although #185 increased the stability of the attribution sums for ResNet, the previous instability seems to have inflated the positive parts of the attributions, which circumvented this problem before #185.
This seems to be related to leaking attributions (#193) combined with skip connections, which can cause negative attributions.
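To illustrate why skip connections can produce negative attributions: at a residual addition, epsilon-style LRP splits the incoming relevance proportionally to each branch's contribution, and when the two branches nearly cancel, the split blows up with opposite signs. A toy numpy sketch with made-up scalar branch values:

```python
import numpy as np

# At a residual addition y = x + f(x), epsilon-LRP splits the incoming
# relevance proportionally to each branch's contribution:
#   R_x = x / (x + f(x)) * R_y,    R_f = f(x) / (x + f(x)) * R_y
R_y = 1.0
x, fx = 2.0, -1.5      # branches with opposite signs, nearly cancelling
z = x + fx             # small denominator: 0.5
R_x = x / z * R_y      # 4.0
R_f = fx / z * R_y     # -3.0
# conservation holds (R_x + R_f == R_y), but the branch relevances are
# inflated with opposite signs, so negative relevance can propagate onward
print(R_x, R_f, R_x + R_f)
```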
A quickfix for EpsilonGammaBox is to use a slightly higher gamma value.
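To see why a higher gamma helps: the gamma rule biases the propagation towards positive weights by using W + gamma * W^+ in place of W, which shrinks negative relevance contributions as gamma grows. A minimal numpy sketch for one linear layer (independent of zennit; the layer values are made up so that gamma=0 yields a negative input relevance):

```python
import numpy as np

def gamma_rule_linear(a, W, R_out, gamma, eps=1e-9):
    # gamma rule: favor positive weights via W + gamma * W^+
    Wg = W + gamma * np.clip(W, 0.0, None)
    z = a @ Wg                   # modified pre-activations
    s = R_out / (z + eps)        # stabilized division (z > 0 in this example)
    return a * (Wg @ s)          # relevance for each input neuron

a = np.array([1.0, 1.0])
W = np.array([[ 2.0, -3.0],
              [-1.0,  4.0]])
R_out = np.array([1.0, 1.0])

print(gamma_rule_linear(a, W, R_out, gamma=0.0))  # negative entry appears
print(gamma_rule_linear(a, W, R_out, gamma=1.0))  # all entries positive
```

In both cases the total relevance is conserved (it still sums to `R_out.sum()`); the higher gamma only redistributes it away from the negative contributions.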