DeepLIFT with Softmax - normalization error #1367

Open
AITheorem opened this issue Oct 13, 2024 · 0 comments

In the implementation of DeepLIFT, there is a normalization step that gets applied to contributions from the Softmax module. I am pretty sure this step comes from a misreading of the DeepLIFT paper. Section 3.6 of the paper (link) recommends the following:

  • (a) In the case of Softmax outputs, we may prefer to compute contributions to the logits rather than contributions to the Softmax outputs.
  • (b) If we compute contributions to the logits, then we can normalize the contributions.

So we actually shouldn't be normalizing the contributions to the Softmax outputs at all.
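To make (b) concrete, the Section 3.6 normalization is meant to be applied to contributions to the logits, not to the Softmax outputs. A minimal sketch of that step on its own (my notation and function name, not code from this repo):

import torch

def normalize_logit_contributions(logit_contribs: torch.Tensor) -> torch.Tensor:
    # logit_contribs: contributions to each of the n logits for one example.
    # Section 3.6 (as I read it): C'_i = C_i - (1/n) * sum_j C_j,
    # applied to logit contributions only, never to softmax-output contributions.
    n = logit_contribs.numel()
    return logit_contribs - logit_contribs.sum() / n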

Expected behaviour

  1. We should remove the normalization step from the Softmax code.
  2. We should include a warning somewhere that the user may prefer to use logits rather than Softmax outputs.
    • Maybe in the docstring?
    • Or potentially a runtime warning when we spot a Softmax module, which could also mention the change of behaviour from (1).
  3. In an ideal world there would be an additional parameter for (a) which specifies "use the logits instead of the softmax outputs" (see the usage sketch after this list).
    • But I'm not sure we have a concept of "this is the penultimate layer". Maybe we just ask people to pass in a forward func that outputs logits.
  4. In an ideal world there would be an additional parameter for (b) which specifies "we are calculating contributions to the logits, please normalize the contributions in the (linear) logit layer".
    • But to do this we would need a hook around the penultimate linear layer (or around the final linear layer if we expect the forward func to produce logits), so this also depends on identifying the penultimate / final layer.
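
A usage sketch for (3), assuming we simply ask the user to pass a model / forward func that stops at the logits (the toy model and shapes here are made up for illustration; DeepLift and attribute are the existing Captum API):

import torch
import torch.nn as nn
from captum.attr import DeepLift

# toy model that ends in a Softmax
model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5), nn.Softmax(dim=1))

# attribute with respect to the logits by dropping the final Softmax
logit_model = nn.Sequential(*list(model.children())[:-1])

inputs = torch.randn(4, 10)
baselines = torch.zeros(4, 10)

attributions = DeepLift(logit_model).attribute(inputs, baselines=baselines, target=0)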

I am happy to implement these myself, but I would appreciate feedback, particularly on (3) and (4).

Current code with normalization

def softmax(
    module: Module,
    inputs: Tensor,
    outputs: Tensor,
    grad_input: Tensor,
    grad_output: Tensor,
    eps: float = 1e-10,
) -> Tensor:
    delta_in, delta_out = _compute_diffs(inputs, outputs)

    grad_input_unnorm = torch.where(
        abs(delta_in) < eps, grad_input, grad_output * delta_out / delta_in
    )
    # normalizing
    n = grad_input.numel()

    # updating only the first half
    new_grad_inp = grad_input_unnorm - grad_input_unnorm.sum() * 1 / n
    return new_grad_inp
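
For comparison, a sketch of (1), i.e. the same hook with the normalization removed so that Softmax is treated like the other nonlinearities (this is the proposed change, not the current code):

def softmax(
    module: Module,
    inputs: Tensor,
    outputs: Tensor,
    grad_input: Tensor,
    grad_output: Tensor,
    eps: float = 1e-10,
) -> Tensor:
    delta_in, delta_out = _compute_diffs(inputs, outputs)

    # multiplier = delta_out / delta_in, falling back to the gradient when
    # delta_in is (near) zero, with no extra normalization afterwards
    return torch.where(
        abs(delta_in) < eps, grad_input, grad_output * delta_out / delta_in
    )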