Core: Multiple input/param gradient modification
- change the core Hook to support modifying multiple inputs and params
- for this, each input and parameter that requires a gradient is now
  hooked, and a backward that is aware of the current 'sink' is called
  for each
- use a view instead of the custom Identity to produce a .grad_fn

Note:
- this may be a breaking change for custom hooks based on the old
  implementation (a sketch of the new backward interface follows below)

TODO:
- finish implementation:
    - parameters have no grad_fn, and we cannot simply overwrite them
      with a view; hooking directly with tensor hooks is problematic
      when the parameters are used in different functions
    - there may be a better approach than calling the backward function
      once per 'sink', although the current implementation may allow
      for better modularity
    - multiple outputs are still not supported; it may be worth
      thinking about how to support them, though this may be better
      done at a later stage
- implement tests
  - new tests for the new functionality: multiple inputs and params in
    hooks
  - fix old tests that assume the use of Identity and are not sink-aware
- add documentation
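
A rough sketch of what a custom rule looks like under the new interface
(the class ClampedPass and its behaviour are hypothetical and only for
illustration; the signature and the meaning of grad_sink follow from the
diff below):

import torch
from zennit.core import Hook

class ClampedPass(Hook):
    '''Hypothetical elementwise rule passing only the positive part of the
    upper-layer relevance through.'''
    def backward(self, module, grad_input, grad_output, grad_sink):
        # backward is now called once per gradient 'sink': grad_sink is the
        # input position (an int) or, for layers with parameters, the
        # parameter name (a str, e.g. 'weight' or 'bias')
        return grad_output[0].clamp(min=0)

Note that a single tensor may now be returned instead of a one-element
tuple; the core wrapper normalizes the result (see pre_forward below).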
chr5tphr committed Nov 4, 2022
1 parent d46f3e7 commit 8f11583
Showing 2 changed files with 50 additions and 51 deletions.
87 changes: 43 additions & 44 deletions src/zennit/core.py
@@ -356,19 +356,6 @@ def collect_leaves(module):
yield module


class Identity(torch.autograd.Function):
'''Identity to add a grad_fn to a tensor, so a backward hook can be applied.'''
@staticmethod
def forward(ctx, *inputs):
'''Forward identity.'''
return inputs

@staticmethod
def backward(ctx, *grad_outputs):
'''Backward identity.'''
return grad_outputs


class Hook:
'''Base class for hooks to be used to compute layer-wise attributions.'''
def __init__(self):
@@ -381,29 +368,41 @@ def pre_forward(self, module, input):
hook_ref = weakref.ref(self)

@functools.wraps(self.backward)
def wrapper(grad_input, grad_output):
def wrapper(grad_input, grad_output, grad_sink):
hook = hook_ref()
if hook is not None and hook.active:
return hook.backward(module, grad_input, hook.stored_tensors['grad_output'])
result = hook.backward(module, grad_output, hook.stored_tensors['grad_output'], grad_sink=grad_sink)
if not isinstance(result, tuple):
result = (result,)
if grad_input is None:
return result[0]
return result
return None

if not isinstance(input, tuple):
input = (input,)

# only if gradient required
if input[0].requires_grad:
# add identity to ensure .grad_fn exists
post_input = Identity.apply(*input)
post_input = tuple(tensor.view_as(tensor) for tensor in input)

# hook required gradient sinks
for grad_sink, tensor in enumerate(post_input):
# register the input tensor gradient hook
self.tensor_handles.append(
post_input[0].grad_fn.register_hook(wrapper)
)
# work around to support in-place operations
post_input = tuple(elem.clone() for elem in post_input)
else:
# no gradient required
post_input = input
return post_input[0] if len(post_input) == 1 else post_input
if tensor.grad_fn is not None:
# grad_fn for inputs is here the view function applied above
self.tensor_handles.append(
tensor.grad_fn.register_hook(functools.partial(wrapper, grad_sink=grad_sink))
)
# hook required gradient sinks
for grad_sink, tensor in module.named_parameters():
if tensor.requires_grad:
# TODO: use grad_fn (need to store parameter views for the model...), otherwise the hook could be
# called for unrelated gradients
self.tensor_handles.append(
tensor.register_hook(functools.partial(wrapper, None, grad_sink=grad_sink))
)

# torch.nn.Module converts single tensors to tuples anyway, so we can always return a tuple here
return post_input

def post_forward(self, module, input, output):
'''Register a backward-hook to the resulting tensor right after the forward.'''
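
As a standalone illustration (plain PyTorch, not part of the diff) of why
the views replace the former Identity: a view of a tensor that requires a
gradient gets its own grad_fn node, which can carry a backward hook.

import torch

x = torch.ones(3, requires_grad=True)
view = x.view_as(x)

def show(grad_input, grad_output):
    # grad_input flows on towards x, grad_output arrives at the view node
    print('grad_input:', grad_input)
    print('grad_output:', grad_output)

view.grad_fn.register_hook(show)
(view * 2.0).sum().backward()  # both tuples contain a tensor of twos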
@@ -413,28 +412,28 @@ def post_forward(self, module, input, output):
def wrapper(grad_input, grad_output):
hook = hook_ref()
if hook is not None and hook.active:
return hook.pre_backward(module, grad_input, grad_output)
return hook.pre_backward(module, grad_output)
return None

if not isinstance(output, tuple):
output = (output,)
hookable_output = output
if not isinstance(hookable_output, tuple):
hookable_output = (hookable_output,)

# only if gradient required
if output[0].grad_fn is not None:
if hookable_output[0].requires_grad:
# register the output tensor gradient hook
self.tensor_handles.append(
output[0].grad_fn.register_hook(wrapper)
hookable_output[0].grad_fn.register_hook(wrapper)
)
return output[0] if len(output) == 1 else output

def pre_backward(self, module, grad_input, grad_output):
def pre_backward(self, module, grad_output):
'''Store the grad_output for the backward hook'''
self.stored_tensors['grad_output'] = grad_output

def forward(self, module, input, output):
'''Hook applied during forward-pass'''

def backward(self, module, grad_input, grad_output):
def backward(self, module, grad_input, grad_output, grad_sink):
'''Hook applied during backward-pass'''

def copy(self):
@@ -522,18 +521,18 @@ def forward(self, module, input, output):
'''Forward hook to save module in-/outputs.'''
self.stored_tensors['input'] = input

def backward(self, module, grad_input, grad_output):
def backward(self, module, grad_input, grad_output, grad_sink):
'''Backward hook to compute LRP based on the class attributes.'''
original_input = self.stored_tensors['input'][0].clone()
original_inputs = [tensor.view_as(tensor) for tensor in self.stored_tensors['input']]
inputs = []
outputs = []
for in_mod, param_mod, out_mod in zip(self.input_modifiers, self.param_modifiers, self.output_modifiers):
input = in_mod(original_input).requires_grad_()
input_args = [in_mod(tensor).requires_grad_() for tensor in original_inputs]
with ParamMod.ensure(param_mod)(module) as modified, torch.autograd.enable_grad():
output = modified.forward(input)
output = out_mod(output)
inputs.append(input)
outputs.append(output)
output = modified.forward(*input_args)
# decide for which argument to compute the relevance
inputs.append(input_args[grad_sink] if isinstance(grad_sink, int) else getattr(modified, grad_sink))
outputs.append(out_mod(output))
grad_outputs = self.gradient_mapper(grad_output[0], outputs)
gradients = torch.autograd.grad(
outputs,
@@ -542,7 +541,7 @@ def backward(self, module, grad_input, grad_output):
create_graph=grad_output[0].requires_grad
)
relevance = self.reducer(inputs, gradients)
return tuple(relevance if original.shape == relevance.shape else None for original in grad_input)
return relevance

def copy(self):
'''Return a copy of this hook.
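
A small standalone illustration (assumed Linear module, not part of the
diff) of the two kinds of grad_sink handled above: an int indexes the
modified input arguments, a str names a module parameter.

import torch

linear = torch.nn.Linear(4, 2)
inputs = (torch.randn(3, 4, requires_grad=True),)
output = linear(*inputs)

for grad_sink in (0, 'weight', 'bias'):
    # select the tensor whose gradient/relevance is computed for this sink
    sink = inputs[grad_sink] if isinstance(grad_sink, int) else getattr(linear, grad_sink)
    grad, = torch.autograd.grad(output.sum(), sink, retain_graph=True)
    print(grad_sink, tuple(grad.shape))  # (3, 4), then (2, 4), then (2,)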
14 changes: 7 additions & 7 deletions src/zennit/rules.py
@@ -322,7 +322,7 @@ class Pass(Hook):
If the rule of a layer shall not be any other, is elementwise and shall not be the gradient, the `Pass` rule simply
passes upper layer relevance through to the lower layer.
'''
def backward(self, module, grad_input, grad_output):
def backward(self, module, grad_input, grad_output, grad_sink):
'''Pass through the upper gradient, skipping the one for this layer.'''
return grad_output

@@ -399,16 +399,16 @@ def __init__(self, stabilizer=1e-6, zero_params=None):

class ReLUDeconvNet(Hook):
'''DeconvNet ReLU rule :cite:p:`zeiler2014visualizing`.'''
def backward(self, module, grad_input, grad_output):
def backward(self, module, grad_input, grad_output, grad_sink):
'''Modify ReLU gradient according to DeconvNet :cite:p:`zeiler2014visualizing`.'''
return (grad_output[0].clamp(min=0),)
return grad_output[0].clamp(min=0)


class ReLUGuidedBackprop(Hook):
'''GuidedBackprop ReLU rule :cite:p:`springenberg2015striving`.'''
def backward(self, module, grad_input, grad_output):
def backward(self, module, grad_input, grad_output, grad_sink):
'''Modify ReLU gradient according to GuidedBackprop :cite:p:`springenberg2015striving`.'''
return (grad_input[0] * (grad_output[0] > 0.),)
return grad_input[0] * (grad_output[0] > 0.)


class ReLUBetaSmooth(Hook):
@@ -433,6 +433,6 @@ def forward(self, module, input, output):
'''Remember the input for the backward pass.'''
self.stored_tensors['input'] = input

def backward(self, module, grad_input, grad_output):
def backward(self, module, grad_input, grad_output, grad_sink):
'''Modify ReLU gradient to the smooth softplus gradient :cite:p:`dombrowski2019explanations`.'''
return (torch.sigmoid(self.beta_smooth * self.stored_tensors['input'][0]) * grad_output[0],)
return torch.sigmoid(self.beta_smooth * self.stored_tensors['input'][0]) * grad_output[0]
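
A brief sketch (mirroring the wrapper added in core.py above) of why these
rules may now return a bare tensor instead of a one-element tuple: the core
wrapper normalizes the hook result before handing it to autograd.

import torch

def _normalize(result, grad_input):
    # mirrors the result handling in Hook.pre_forward's wrapper
    if not isinstance(result, tuple):
        result = (result,)
    if grad_input is None:  # parameter sinks are hooked directly on the tensor
        return result[0]
    return result

assert _normalize(torch.ones(3), grad_input=None).shape == (3,)
assert isinstance(_normalize((torch.ones(3),), grad_input=(None,)), tuple)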
