diff --git a/src/zennit/core.py b/src/zennit/core.py
index 0bf2e6f..9519766 100644
--- a/src/zennit/core.py
+++ b/src/zennit/core.py
@@ -356,19 +356,6 @@ def collect_leaves(module):
         yield module
 
 
-class Identity(torch.autograd.Function):
-    '''Identity to add a grad_fn to a tensor, so a backward hook can be applied.'''
-    @staticmethod
-    def forward(ctx, *inputs):
-        '''Forward identity.'''
-        return inputs
-
-    @staticmethod
-    def backward(ctx, *grad_outputs):
-        '''Backward identity.'''
-        return grad_outputs
-
-
 class Hook:
     '''Base class for hooks to be used to compute layer-wise attributions.'''
     def __init__(self):
@@ -381,29 +368,41 @@ def pre_forward(self, module, input):
         hook_ref = weakref.ref(self)
 
         @functools.wraps(self.backward)
-        def wrapper(grad_input, grad_output):
+        def wrapper(grad_input, grad_output, grad_sink):
             hook = hook_ref()
             if hook is not None and hook.active:
-                return hook.backward(module, grad_input, hook.stored_tensors['grad_output'])
+                result = hook.backward(module, grad_output, hook.stored_tensors['grad_output'], grad_sink=grad_sink)
+                if not isinstance(result, tuple):
+                    result = (result,)
+                if grad_input is None:
+                    return result[0]
+                return result
             return None
 
         if not isinstance(input, tuple):
             input = (input,)
 
-        # only if gradient required
-        if input[0].requires_grad:
-            # add identity to ensure .grad_fn exists
-            post_input = Identity.apply(*input)
+        post_input = tuple(tensor.view_as(tensor) for tensor in input)
+
+        # hook required gradient sinks
+        for grad_sink, tensor in enumerate(post_input):
             # register the input tensor gradient hook
-            self.tensor_handles.append(
-                post_input[0].grad_fn.register_hook(wrapper)
-            )
-            # work around to support in-place operations
-            post_input = tuple(elem.clone() for elem in post_input)
-        else:
-            # no gradient required
-            post_input = input
-        return post_input[0] if len(post_input) == 1 else post_input
+            if tensor.grad_fn is not None:
+                # grad_fn for inputs is here the view function applied above
+                self.tensor_handles.append(
+                    tensor.grad_fn.register_hook(functools.partial(wrapper, grad_sink=grad_sink))
+                )
+        # hook required gradient sinks
+        for grad_sink, tensor in module.named_parameters():
+            if tensor.requires_grad:
+                # TODO: use grad_fn (need to store parameter views for the model...), otherwise the hook could be
+                # called for unrelated gradients
+                self.tensor_handles.append(
+                    tensor.register_hook(functools.partial(wrapper, None, grad_sink=grad_sink))
+                )
+
+        # torch.nn.Module converts single tensors to tuples anyway, so we can always return a tuple here
+        return post_input
 
     def post_forward(self, module, input, output):
         '''Register a backward-hook to the resulting tensor right after the forward.'''
@@ -413,28 +412,28 @@ def post_forward(self, module, input, output):
         def wrapper(grad_input, grad_output):
             hook = hook_ref()
             if hook is not None and hook.active:
-                return hook.pre_backward(module, grad_input, grad_output)
+                return hook.pre_backward(module, grad_output)
             return None
 
-        if not isinstance(output, tuple):
-            output = (output,)
+        hookable_output = output
+        if not isinstance(hookable_output, tuple):
+            hookable_output = (hookable_output,)
 
         # only if gradient required
-        if output[0].grad_fn is not None:
+        if hookable_output[0].requires_grad:
             # register the output tensor gradient hook
             self.tensor_handles.append(
-                output[0].grad_fn.register_hook(wrapper)
+                hookable_output[0].grad_fn.register_hook(wrapper)
             )
-        return output[0] if len(output) == 1 else output
 
-    def pre_backward(self, module, grad_input, grad_output):
+    def pre_backward(self, module, grad_output):
         '''Store the grad_output for the backward hook'''
         self.stored_tensors['grad_output'] = grad_output
 
     def forward(self, module, input, output):
         '''Hook applied during forward-pass'''
 
-    def backward(self, module, grad_input, grad_output):
+    def backward(self, module, grad_input, grad_output, grad_sink):
         '''Hook applied during backward-pass'''
 
     def copy(self):
@@ -522,18 +521,18 @@ def forward(self, module, input, output):
         '''Forward hook to save module in-/outputs.'''
         self.stored_tensors['input'] = input
 
-    def backward(self, module, grad_input, grad_output):
+    def backward(self, module, grad_input, grad_output, grad_sink):
         '''Backward hook to compute LRP based on the class attributes.'''
-        original_input = self.stored_tensors['input'][0].clone()
+        original_inputs = [tensor.view_as(tensor) for tensor in self.stored_tensors['input']]
         inputs = []
         outputs = []
         for in_mod, param_mod, out_mod in zip(self.input_modifiers, self.param_modifiers, self.output_modifiers):
-            input = in_mod(original_input).requires_grad_()
+            input_args = [in_mod(tensor).requires_grad_() for tensor in original_inputs]
             with ParamMod.ensure(param_mod)(module) as modified, torch.autograd.enable_grad():
-                output = modified.forward(input)
-                output = out_mod(output)
-            inputs.append(input)
-            outputs.append(output)
+                output = modified.forward(*input_args)
+                # decide for which argument to compute the relevance
+                inputs.append(input_args[grad_sink] if isinstance(grad_sink, int) else getattr(modified, grad_sink))
+                outputs.append(out_mod(output))
         grad_outputs = self.gradient_mapper(grad_output[0], outputs)
         gradients = torch.autograd.grad(
             outputs,
@@ -542,7 +541,7 @@ def backward(self, module, grad_input, grad_output):
             create_graph=grad_output[0].requires_grad
         )
         relevance = self.reducer(inputs, gradients)
-        return tuple(relevance if original.shape == relevance.shape else None for original in grad_input)
+        return relevance
 
     def copy(self):
         '''Return a copy of this hook.
diff --git a/src/zennit/rules.py b/src/zennit/rules.py
index 4c7554e..7d95430 100644
--- a/src/zennit/rules.py
+++ b/src/zennit/rules.py
@@ -322,7 +322,7 @@ class Pass(Hook):
     If the rule of a layer shall not be any other, is elementwise and shall not be the gradient, the `Pass`
     rule simply passes upper layer relevance through to the lower layer.
     '''
-    def backward(self, module, grad_input, grad_output):
+    def backward(self, module, grad_input, grad_output, grad_sink):
         '''Pass through the upper gradient, skipping the one for this layer.'''
         return grad_output
 
@@ -399,16 +399,16 @@ def __init__(self, stabilizer=1e-6, zero_params=None):
 
 class ReLUDeconvNet(Hook):
     '''DeconvNet ReLU rule :cite:p:`zeiler2014visualizing`.'''
-    def backward(self, module, grad_input, grad_output):
+    def backward(self, module, grad_input, grad_output, grad_sink):
         '''Modify ReLU gradient according to DeconvNet :cite:p:`zeiler2014visualizing`.'''
-        return (grad_output[0].clamp(min=0),)
+        return grad_output[0].clamp(min=0)
 
 
 class ReLUGuidedBackprop(Hook):
     '''GuidedBackprop ReLU rule :cite:p:`springenberg2015striving`.'''
-    def backward(self, module, grad_input, grad_output):
+    def backward(self, module, grad_input, grad_output, grad_sink):
         '''Modify ReLU gradient according to GuidedBackprop :cite:p:`springenberg2015striving`.'''
-        return (grad_input[0] * (grad_output[0] > 0.),)
+        return grad_input[0] * (grad_output[0] > 0.)
 
 
 class ReLUBetaSmooth(Hook):
@@ -433,6 +433,6 @@ def forward(self, module, input, output):
         '''Remember the input for the backward pass.'''
        self.stored_tensors['input'] = input
 
-    def backward(self, module, grad_input, grad_output):
+    def backward(self, module, grad_input, grad_output, grad_sink):
         '''Modify ReLU gradient to the smooth softplus gradient :cite:p:`dombrowski2019explanations`.'''
-        return (torch.sigmoid(self.beta_smooth * self.stored_tensors['input'][0]) * grad_output[0],)
+        return torch.sigmoid(self.beta_smooth * self.stored_tensors['input'][0]) * grad_output[0]
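
Note on the changed interface: with this patch, Hook.backward receives an additional grad_sink argument, which is either an integer index of a module input or a parameter name such as 'weight', and it returns the relevance for that sink directly instead of a tuple matched against grad_input. The following is a minimal sketch of a custom rule written against the patched interface; the class name ScaledPass and its factor parameter are illustrative only and not part of zennit's API.

from zennit.core import Hook


class ScaledPass(Hook):
    '''Hypothetical elementwise rule: pass upper-layer relevance through, scaled by a constant.'''
    def __init__(self, factor=1.0):
        super().__init__()
        self.factor = factor

    def backward(self, module, grad_input, grad_output, grad_sink):
        '''Scale the incoming relevance; leave parameter gradients untouched.'''
        if not isinstance(grad_sink, int):
            # grad_sink is a parameter name (e.g. 'weight'); returning None
            # keeps that parameter's gradient unchanged
            return None
        # grad_sink is the index of a module input; for an elementwise module
        # the incoming relevance already has the input's shape
        return grad_output[0] * self.factor

As with the patched ReLU rules above, returning a single tensor is sufficient: the wrapper registered in Hook.pre_forward wraps it into a tuple for the grad_fn hooks on the inputs and passes it through unchanged for the parameter hooks registered via Tensor.register_hook.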