ResNet: Unstable Attributions #194

Open · chr5tphr opened this issue Aug 9, 2023 · 3 comments
Labels: bug (Something isn't working)

chr5tphr (Owner) commented Aug 9, 2023

Since the introduction of #185, ResNet18 attributions have negative attribution sums in the input layer, leading to bad attributions. Although #185 increased the stability of the attribution sums for ResNet, the previous instability seems to have inflated the positive parts of the attributions, which masked this problem before #185.

This seems to be related to leaking attributions (#193) combined with skip connections, which can cause negative attributions.

A quick fix for EpsilonGammaBox is to use a slightly higher gamma value.
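
For illustration, a minimal sketch of this quick fix (the low/high input bounds and the concrete gamma value are placeholders and depend on the input normalization; zennit's default gamma is 0.25):

from zennit.composites import EpsilonGammaBox
from zennit.torchvision import ResNetCanonizer

# low/high are the input bounds used by the ZBox rule at the first layer;
# the values below are placeholders for an ImageNet-like normalization.
# gamma is raised above the default of 0.25 as the quick fix mentioned above
composite = EpsilonGammaBox(
    low=-3.0,
    high=3.0,
    gamma=0.5,
    canonizers=[ResNetCanonizer()],
)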

MikiFER commented Oct 27, 2023

Hi @chr5tphr, is there any news on this bug? I noticed that when setting zero_params='bias' for the EpsilonPlusFlat composite, the total attribution sums to ~1, but negative attributions may still occur. Sometimes the negative part is small, but sometimes it amounts to ~-1 or more, which is significant; the positive part then amounts to almost ~2 to compensate.
Here is the code I used to reproduce this issue (multiple runs may be required):

import torch
from torchvision.models import resnet34
from zennit.torchvision import ResNetCanonizer
from zennit.composites import EpsilonPlusFlat
import matplotlib.pyplot as plt

# randomly initialized model
model = resnet34(weights=None)

# create a composite, specifying the canonizers
composite = EpsilonPlusFlat(canonizers=[ResNetCanonizer()], zero_params='bias')
# one-hot output relevance for class 437
target = torch.eye(1000)[[437]]
input_data = torch.rand(1, 3, 224, 224)
input_data.requires_grad = True

# compute the attribution as the modified gradient within the composite's context
with composite.context(model) as modified_model:
    output = modified_model(input_data)
    attribution, = torch.autograd.grad(output, input_data, target)

# sum the relevance over the color channels
relevance = attribution.cpu().sum(1).squeeze(0)

# if any negative relevance occurs, print the negative, positive and total sums
if torch.any(relevance < 0):
    print(relevance[relevance < 0].sum())
    print(relevance[relevance > 0].sum())
    print(relevance.sum())
plt.imshow(relevance.numpy())
plt.show()

If you don't have the time to deal with this issue, could you point me in the right direction so that I can try to fix it myself? Resolving it is critical for me to continue with my research.

chr5tphr (Owner, Author) commented

Hey @MikiFER

I am not sure that what you are experiencing is related to this issue, as the relevance still sums to 1.
The bug may be a little elusive, but in general it is okay for negative relevance to occur (for EpsilonPlusFlat, anyway).
You should be fine on 0.5.1; the bug referred to in this issue exists only on master and was caused by #185.
Did you use master?
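
For reference, a quick way to check which version is installed (this uses only the standard library, nothing zennit-specific):

from importlib.metadata import version

# prints e.g. '0.5.1'; an install from master typically reports
# a development version string instead
print(version('zennit'))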

Here's a little snippet to check the model relevance in detail. I have added an extra rule to switch off the residual branch, in order to circumvent the LRP instabilities discussed in #148. The canonizer is simply ineffective for vgg11, which does not need one (although, since ResNetCanonizer includes MergeBatchNorm, vgg11_bn would also work):

from itertools import islice
import torch
from torchvision.models import resnet34, vgg11
from zennit.torchvision import ResNetCanonizer
from zennit.composites import EpsilonPlusFlat
from zennit.core import Hook
from zennit.layer import Sum


# rule that passes all relevance to a single summand of the Sum layer,
# effectively switching off the other (residual) branch
class SumSingle(Hook):
    def __init__(self, dim=1):
        super().__init__()
        self.dim = dim

    def backward(self, module, grad_input, grad_output):
        elems = [torch.zeros_like(grad_output[0])] * (grad_input[0].shape[-1])
        elems[self.dim] = grad_output[0]
        return (torch.stack(elems, dim=-1),)


# forward hook to store each module's output and retain its gradient,
# which holds the relevance after the backward pass
def store_hook(module, input, output):
    module.output = output
    output.retain_grad()


models = {'resnet34': resnet34, 'vgg11': vgg11}

for model_name, model_fn in models.items():
    torch.manual_seed(0xdeadbeef + 3)
    model = model_fn(weights=None)
    model.eval()

    # create a composite, specifying the canonizers
    composite = EpsilonPlusFlat(
        layer_map=[(Sum, SumSingle(1))],
        canonizers=[ResNetCanonizer()],
        zero_params='bias'
    )
    target = torch.eye(1000)[[437]]
    input_data = torch.rand(1, 3, 224, 224)
    input_data.requires_grad = True

    with composite.context(model) as modified_model:
        # store the outputs of all leaf modules
        handles = [
            module.register_forward_hook(store_hook)
            for module in model.modules()
            if not list(islice(module.children(), 1))
        ]
        output = modified_model(input_data)
        attribution, = torch.autograd.grad(output, input_data, target)

    relevance = attribution.cpu().sum(1).squeeze(0)

    labels = [('input', 'input', attribution)] + [
        (name, type(module).__name__, module.output.grad)
        for name, module in model.named_modules()
        if hasattr(module, 'output')
    ]
    maxname, maxclsname = [max(len(obj[i]) for obj in labels) for i in (0, 1)]

    print(f'\nModel: {model_name}')
    for name, clsname, grad in labels:
        print(
            f'  {name:<{maxname}s} ({clsname:<{maxclsname}s}): '
            f'min: {grad.min():+.7f}, '
            f'max: {grad.max():+.7f}, '
            f'sum: {grad.sum():+.7f}'
        )

And this is the output I get on 0.5.1:

Model: resnet34
  input                 (input            ): min: +0.0000003, max: +0.0000542, sum: +0.9999950
  conv1                 (Conv2d           ): min: +0.0000000, max: +0.0001266, sum: +0.9999949
  bn1                   (BatchNorm2d      ): min: +0.0000000, max: +0.0001266, sum: +0.9999949
  relu                  (ReLU             ): min: +0.0000000, max: +0.0001266, sum: +0.9999949
  maxpool               (MaxPool2d        ): min: +0.0000000, max: +0.0000341, sum: +0.9999950
  layer1.0.conv1        (Conv2d           ): min: +0.0000000, max: +0.0000796, sum: +0.9999954
  layer1.0.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0000796, sum: +0.9999954
  layer1.0.relu         (ReLU             ): min: +0.0000000, max: +0.0000468, sum: +0.9999962
  layer1.0.conv2        (Conv2d           ): min: +0.0000000, max: +0.0000468, sum: +0.9999962
  layer1.0.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0000468, sum: +0.9999962
  layer1.1.conv1        (Conv2d           ): min: +0.0000000, max: +0.0000834, sum: +0.9999965
  layer1.1.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0000834, sum: +0.9999965
  layer1.1.relu         (ReLU             ): min: +0.0000000, max: +0.0000536, sum: +0.9999971
  layer1.1.conv2        (Conv2d           ): min: +0.0000000, max: +0.0000536, sum: +0.9999971
  layer1.1.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0000536, sum: +0.9999971
  layer1.2.conv1        (Conv2d           ): min: +0.0000000, max: +0.0000902, sum: +0.9999974
  layer1.2.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0000902, sum: +0.9999974
  layer1.2.relu         (ReLU             ): min: +0.0000000, max: +0.0001007, sum: +0.9999977
  layer1.2.conv2        (Conv2d           ): min: +0.0000000, max: +0.0001007, sum: +0.9999977
  layer1.2.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0001007, sum: +0.9999977
  layer2.0.conv1        (Conv2d           ): min: +0.0000000, max: +0.0001177, sum: +0.9999980
  layer2.0.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0001177, sum: +0.9999980
  layer2.0.relu         (ReLU             ): min: +0.0000000, max: +0.0001770, sum: +0.9999983
  layer2.0.conv2        (Conv2d           ): min: +0.0000000, max: +0.0001770, sum: +0.9999983
  layer2.0.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0001770, sum: +0.9999983
  layer2.0.downsample.0 (Conv2d           ): min: +0.0000000, max: +0.0000000, sum: +0.0000000
  layer2.0.downsample.1 (BatchNorm2d      ): min: +0.0000000, max: +0.0000000, sum: +0.0000000
  layer2.1.conv1        (Conv2d           ): min: +0.0000000, max: +0.0001312, sum: +0.9999984
  layer2.1.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0001312, sum: +0.9999984
  layer2.1.relu         (ReLU             ): min: +0.0000000, max: +0.0001304, sum: +0.9999986
  layer2.1.conv2        (Conv2d           ): min: +0.0000000, max: +0.0001304, sum: +0.9999986
  layer2.1.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0001304, sum: +0.9999986
  layer2.2.conv1        (Conv2d           ): min: +0.0000000, max: +0.0001452, sum: +0.9999987
  layer2.2.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0001452, sum: +0.9999987
  layer2.2.relu         (ReLU             ): min: +0.0000000, max: +0.0001383, sum: +0.9999988
  layer2.2.conv2        (Conv2d           ): min: +0.0000000, max: +0.0001383, sum: +0.9999988
  layer2.2.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0001383, sum: +0.9999988
  layer2.3.conv1        (Conv2d           ): min: +0.0000000, max: +0.0001657, sum: +0.9999989
  layer2.3.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0001657, sum: +0.9999989
  layer2.3.relu         (ReLU             ): min: +0.0000000, max: +0.0001711, sum: +0.9999990
  layer2.3.conv2        (Conv2d           ): min: +0.0000000, max: +0.0001711, sum: +0.9999990
  layer2.3.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0001711, sum: +0.9999990
  layer3.0.conv1        (Conv2d           ): min: +0.0000000, max: +0.0003639, sum: +0.9999990
  layer3.0.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0003639, sum: +0.9999990
  layer3.0.relu         (ReLU             ): min: +0.0000000, max: +0.0002886, sum: +0.9999992
  layer3.0.conv2        (Conv2d           ): min: +0.0000000, max: +0.0002886, sum: +0.9999992
  layer3.0.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0002886, sum: +0.9999992
  layer3.0.downsample.0 (Conv2d           ): min: +0.0000000, max: +0.0000000, sum: +0.0000000
  layer3.0.downsample.1 (BatchNorm2d      ): min: +0.0000000, max: +0.0000000, sum: +0.0000000
  layer3.1.conv1        (Conv2d           ): min: +0.0000000, max: +0.0003569, sum: +0.9999993
  layer3.1.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0003569, sum: +0.9999993
  layer3.1.relu         (ReLU             ): min: +0.0000000, max: +0.0003025, sum: +0.9999993
  layer3.1.conv2        (Conv2d           ): min: +0.0000000, max: +0.0003025, sum: +0.9999993
  layer3.1.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0003025, sum: +0.9999993
  layer3.2.conv1        (Conv2d           ): min: +0.0000000, max: +0.0002958, sum: +0.9999993
  layer3.2.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0002958, sum: +0.9999993
  layer3.2.relu         (ReLU             ): min: +0.0000000, max: +0.0002067, sum: +0.9999994
  layer3.2.conv2        (Conv2d           ): min: +0.0000000, max: +0.0002067, sum: +0.9999994
  layer3.2.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0002067, sum: +0.9999994
  layer3.3.conv1        (Conv2d           ): min: +0.0000000, max: +0.0003350, sum: +0.9999993
  layer3.3.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0003350, sum: +0.9999993
  layer3.3.relu         (ReLU             ): min: +0.0000000, max: +0.0002595, sum: +0.9999993
  layer3.3.conv2        (Conv2d           ): min: +0.0000000, max: +0.0002595, sum: +0.9999993
  layer3.3.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0002595, sum: +0.9999993
  layer3.4.conv1        (Conv2d           ): min: +0.0000000, max: +0.0003304, sum: +0.9999993
  layer3.4.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0003304, sum: +0.9999993
  layer3.4.relu         (ReLU             ): min: +0.0000000, max: +0.0002727, sum: +0.9999993
  layer3.4.conv2        (Conv2d           ): min: +0.0000000, max: +0.0002727, sum: +0.9999993
  layer3.4.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0002727, sum: +0.9999993
  layer3.5.conv1        (Conv2d           ): min: +0.0000000, max: +0.0004377, sum: +0.9999994
  layer3.5.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0004377, sum: +0.9999994
  layer3.5.relu         (ReLU             ): min: +0.0000000, max: +0.0004310, sum: +0.9999993
  layer3.5.conv2        (Conv2d           ): min: +0.0000000, max: +0.0004310, sum: +0.9999993
  layer3.5.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0004310, sum: +0.9999993
  layer4.0.conv1        (Conv2d           ): min: +0.0000000, max: +0.0007261, sum: +0.9999992
  layer4.0.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0007261, sum: +0.9999992
  layer4.0.relu         (ReLU             ): min: +0.0000000, max: +0.0007178, sum: +0.9999994
  layer4.0.conv2        (Conv2d           ): min: +0.0000000, max: +0.0007178, sum: +0.9999994
  layer4.0.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0007178, sum: +0.9999994
  layer4.0.downsample.0 (Conv2d           ): min: +0.0000000, max: +0.0000000, sum: +0.0000000
  layer4.0.downsample.1 (BatchNorm2d      ): min: +0.0000000, max: +0.0000000, sum: +0.0000000
  layer4.1.conv1        (Conv2d           ): min: +0.0000000, max: +0.0008504, sum: +0.9999994
  layer4.1.bn1          (BatchNorm2d      ): min: +0.0000000, max: +0.0008504, sum: +0.9999994
  layer4.1.relu         (ReLU             ): min: +0.0000000, max: +0.0006153, sum: +0.9999993
  layer4.1.conv2        (Conv2d           ): min: +0.0000000, max: +0.0006153, sum: +0.9999993
  layer4.1.bn2          (BatchNorm2d      ): min: +0.0000000, max: +0.0006153, sum: +0.9999993
  layer4.2.conv1        (Conv2d           ): min: -0.0005434, max: +0.0020294, sum: +0.9999993
  layer4.2.bn1          (BatchNorm2d      ): min: -0.0005434, max: +0.0020294, sum: +0.9999993
  layer4.2.relu         (ReLU             ): min: -0.0092405, max: +0.0115777, sum: +0.9999994
  layer4.2.conv2        (Conv2d           ): min: -0.0092405, max: +0.0115777, sum: +0.9999994
  layer4.2.bn2          (BatchNorm2d      ): min: -0.0092405, max: +0.0115777, sum: +0.9999994
  avgpool               (AdaptiveAvgPool2d): min: -0.2474068, max: +0.3167669, sum: +0.9999998
  fc                    (Linear           ): min: +0.0000000, max: +1.0000000, sum: +1.0000000

Model: vgg11
  input        (input            ): min: -0.0002063, max: +0.0002277, sum: +0.9999397
  features.0   (Conv2d           ): min: -0.0000620, max: +0.0000807, sum: +0.9999397
  features.1   (ReLU             ): min: -0.0000620, max: +0.0000807, sum: +0.9999397
  features.2   (MaxPool2d        ): min: -0.0000620, max: +0.0000807, sum: +0.9999397
  features.3   (Conv2d           ): min: -0.0001189, max: +0.0001306, sum: +0.9999405
  features.4   (ReLU             ): min: -0.0001189, max: +0.0001306, sum: +0.9999405
  features.5   (MaxPool2d        ): min: -0.0001189, max: +0.0001306, sum: +0.9999405
  features.6   (Conv2d           ): min: -0.0001095, max: +0.0001081, sum: +0.9999412
  features.7   (ReLU             ): min: -0.0001095, max: +0.0001081, sum: +0.9999412
  features.8   (Conv2d           ): min: -0.0002049, max: +0.0002393, sum: +0.9999417
  features.9   (ReLU             ): min: -0.0002049, max: +0.0002393, sum: +0.9999417
  features.10  (MaxPool2d        ): min: -0.0002049, max: +0.0002393, sum: +0.9999418
  features.11  (Conv2d           ): min: -0.0001591, max: +0.0001764, sum: +0.9999424
  features.12  (ReLU             ): min: -0.0001591, max: +0.0001764, sum: +0.9999424
  features.13  (Conv2d           ): min: -0.0006126, max: +0.0006749, sum: +0.9999429
  features.14  (ReLU             ): min: -0.0006126, max: +0.0006749, sum: +0.9999429
  features.15  (MaxPool2d        ): min: -0.0006126, max: +0.0006749, sum: +0.9999428
  features.16  (Conv2d           ): min: -0.0031778, max: +0.0024524, sum: +0.9999434
  features.17  (ReLU             ): min: -0.0031778, max: +0.0024524, sum: +0.9999434
  features.18  (Conv2d           ): min: -0.0809088, max: +0.0687538, sum: +0.9999435
  features.19  (ReLU             ): min: -0.0809088, max: +0.0687538, sum: +0.9999435
  features.20  (MaxPool2d        ): min: -0.0809088, max: +0.0687538, sum: +0.9999434
  avgpool      (AdaptiveAvgPool2d): min: -0.0809091, max: +0.0687540, sum: +0.9999459
  classifier.0 (Linear           ): min: -0.1945583, max: +0.1917277, sum: +0.9999526
  classifier.1 (ReLU             ): min: -0.1945583, max: +0.1917277, sum: +0.9999526
  classifier.2 (Dropout          ): min: -0.1945583, max: +0.1917277, sum: +0.9999526
  classifier.3 (Linear           ): min: -0.1157313, max: +0.1835920, sum: +0.9999599
  classifier.4 (ReLU             ): min: -0.1157313, max: +0.1835920, sum: +0.9999599
  classifier.5 (Dropout          ): min: -0.1157313, max: +0.1835920, sum: +0.9999599
  classifier.6 (Linear           ): min: +0.0000000, max: +1.0000000, sum: +1.0000000

If you are on master: do you have a more concrete example of the bug?
That would help me pin down the issue.
E.g., the snippet above should produce relevance sums not equal to one. You can try commenting out the layer_map= line in the composite to re-enable relevance through the residual connection.

Maybe skipping the residual branch already fixes your issue?

MikiFER commented Nov 24, 2023

Hi @chr5tphr, thank you for your response. I have since discovered that I had a misunderstanding about the EpsilonPlusFlat composite, and have decided that the pure Alpha1Beta0 rule is actually what I need, because I want to obtain a relevance map with only positive contributions in the input.
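
For what it's worth, a minimal sketch of one way to apply Alpha1Beta0 to all linear layers with zennit's LayerMapComposite (this setup is an illustration, not from this thread; zennit.types.Linear matches both dense and convolutional layers):

import torch
from torchvision.models import resnet34
from zennit.composites import LayerMapComposite
from zennit.rules import AlphaBeta
from zennit.torchvision import ResNetCanonizer
from zennit.types import Linear

# AlphaBeta with alpha=1, beta=0 propagates only positive contributions
composite = LayerMapComposite(
    layer_map=[(Linear, AlphaBeta(alpha=1.0, beta=0.0))],
    canonizers=[ResNetCanonizer()],
)

model = resnet34(weights=None)
model.eval()
input_data = torch.rand(1, 3, 224, 224, requires_grad=True)
target = torch.eye(1000)[[437]]

with composite.context(model) as modified_model:
    output = modified_model(input_data)
    attribution, = torch.autograd.grad(output, input_data, target)

Note that for layers whose inputs are non-negative (e.g. after a ReLU), zennit's ZPlus rule coincides with Alpha1Beta0.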
