Neural Collapse

This package aims to be a reference implementation for the analysis of Neural Collapse (NC) (Papyan et al., 2020). We provide:

  1. Accumulators to collect statistics from the output representations (embeddings) of your pre-trained model.
  2. Measurement (kernel) functions for several canonical and modern NC metrics.
  3. Tiling support for memory-bound settings arising from large embeddings, many classes, and/or limited parallel accelerator (e.g. GPU) memory.

Installation

# install from remote
pip install git+https://github.com/rhubarbwu/neural-collapse.git

# with FAISS
pip install "git+https://github.com/rhubarbwu/neural-collapse.git#egg=neural_collapse[faiss]"

# install locally from a repository clone [with FAISS]
git clone https://github.com/rhubarbwu/neural-collapse.git
pip install -e "neural-collapse[faiss]"

Usage

import neural_collapse as nc

We assume that you:

  • Have already pre-trained your model, or are training it in a loop you can program, with the top-layer classifier weights available.
  • Have your iterable dataloader(s) available. Make sure the training data is the same data with which you trained your model.
  • Have model evaluation functions or results; these are technically optional but ideal.

For use cases with large embeddings or many classes, we recommend using a hardware accelerator (e.g. cuda).
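
The snippets below also refer to a device, a pre-trained model, its top-layer classifier weights, and a Features holder whose value field caches the penultimate-layer embeddings of each batch. The package does not provide these; one possible setup, sketched here under the assumption of a ResNet-18-style model whose final fc layer consumes 512-dimensional features, captures the embeddings with a forward hook:

import torch
import torchvision

device = "cuda" if torch.cuda.is_available() else "cpu"

# your pre-trained model; a ResNet-18 is used purely for illustration
model = torchvision.models.resnet18(num_classes=10).to(device).eval()
weights = model.fc.weight.detach()  # top-layer classifier weights (layout assumed)

class Features:  # simple holder populated by the hook on each forward pass
    value = None

def capture(module, inputs, output):
    Features.value = inputs[0].detach()  # embeddings entering the classifier head

model.fc.register_forward_hook(capture)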

Accumulators

You'll need to collect (i.e. "accumulate") statistics from your learned representations. Here we outline a basic example on the MNIST dataset with K=10 classes and embeddings of size D=512.

from neural_collapse.accumulate import (CovarAccumulator, DecAccumulator,
                                        MeanAccumulator, VarNormAccumulator)

Mean Embedding Accumulators (for NC* in general)

mean_accum = MeanAccumulator(10, 512, "cuda")
for i, (images, labels) in enumerate(train_loader):
    images, labels = images.to(device), labels.to(device)
    outputs = model(images)  # forward pass; embeddings are captured in Features.value
    mean_accum.accumulate(Features.value, labels)
means, mG = mean_accum.compute()  # per-class means and the global mean mG

Variance Accumulators (for NC1)

For measuring within-class variability collapse (NC1), you would typically collect within-class covariances (covar_accum below); note that this can be memory-intensive, at order K*D*D (for example, K=100 classes with D=2048-dimensional float32 embeddings already amount to roughly 1.7 GB of covariances).

covar_accum = CovarAccumulator(10, 512, "cuda", M=means)
var_norms_accum = VarNormAccumulator(10, 512, "cuda", M=means) # for CDNV
for i, (images, labels) in enumerate(train_loader):
    images, labels = images.to(device), labels.to(device)
    outputs = model(images)
    covar_accum.accumulate(Features.value, labels, means)
    var_norms_accum.accumulate(Features.value, labels, means)
covar_within = covar_accum.compute()
var_norms, _ = var_norms_accum.compute() # for CDNV

NC1 can also be empirically measured using the class-distance normalized variance (CDNV) (Galanti et al., 2021), which only requires collecting within-class variance norms at order K.
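
Concretely, the pairwise CDNV between classes i and j is (Var_i + Var_j) / (2 * ||mu_i - mu_j||^2), where Var_c is the within-class variance of class c. The sketch below is a hypothetical reference helper, not the package's variability_cdnv, and assumes var_norms is a length-K vector of per-class variances:

import torch

def cdnv_grid(var_norms, means):
    # squared pairwise distances between class means: ||mu_i - mu_j||^2
    dists_sq = torch.cdist(means, means).pow(2)
    # sums of within-class variances: Var_i + Var_j
    var_sums = var_norms.unsqueeze(0) + var_norms.unsqueeze(1)
    # the diagonal divides by zero; only off-diagonal entries are meaningful
    return var_sums / (2 * dists_sq)

Averaging the off-diagonal entries gives a single NC1 score that shrinks as within-class variability collapses relative to between-class separation.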

Decision Agreement Accumulators (for NC4)

Measuring the convergence of the linear classifier's behaviour to that of the implicit nearest class-centre (NCC) classifier has since been extended to generalization on unseen (e.g. validation or test) data.

dec_accum = DecAccumulator(10, 512, "cuda", M=means, W=weights)
dec_accum.create_index(means) # optionally use FAISS index for NCC
for i, (images, labels) in enumerate(test_loader):
    images, labels = images.to(device), labels.to(device)
    outputs = model(images)

    # mean embeddings (only) necessary again if not using FAISS index
    if dec_accum.index is None:
        dec_accum.accumulate(Features.value, labels, weights, means)
    else:
        dec_accum.accumulate(Features.value, labels, weights)

Out-of-Distribution (OoD) Means (for NC5)

For OoD detection (NC5) (Ammar et al., 2024), collect class-mean embeddings from an out-of-distribution dataset.

ood_mean_accum = MeanAccumulator(10, 512, "cuda")
for i, (images, labels) in enumerate(ood_loader):
    images, labels = images.to(device), labels.to(device)
    outputs = model(images)
    ood_mean_accum.accumulate(Features.value, labels)
_, mG_ood = ood_mean_accum.compute()

Measurements

Here's our snippet for an example on the MNIST dataset.

from neural_collapse.measure import (clf_ncc_agreement, covariance_pinv,
                                     covariance_ratio, orthogonality_deviation,
                                     self_duality_error, simplex_etf_error,
                                     variability_cdnv)
# the kernel_stats, log_kernel, and similarities functions used below must also
# be imported from the package

results = {
    "nc1_pinv": covariance_ratio(covar_within, means, mG),
    "nc1_svd": covariance_ratio(covar_within, means, mG, "svd"),
    "nc1_quot": covariance_ratio(covar_within, means, mG, "quotient"),
    "nc1_cdnv": variability_cdnv(var_norms, means),
    "nc2_etf_err": simplex_etf_error(means, mG),
    "nc2g_dist": kernel_stats(means, mG)[1],
    "nc2g_log": kernel_stats(means, mG, kernel=log_kernel)[1],
    "nc3_dual_err": self_duality_error(weights, means, mG),
    "nc3u_uni_dual": similarities(weights, means, mG).var().item(),
    "nc4_agree": clf_ncc_agreement(dec_accum),
    "nc5_ood_dev": orthogonality_deviation(means, mG_ood),
}
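
For intuition about the NC2 terms: under neural collapse, the centred class means approach a simplex equiangular tight frame (ETF), i.e. equal norms and pairwise cosine similarities of -1/(K-1). The helper below is illustrative only (not the package's simplex_etf_error); it measures the average deviation of those cosines from the ETF target:

import torch

def etf_cosine_gap(means, mG):
    M = means - mG                       # centre the class means
    M = M / M.norm(dim=1, keepdim=True)  # unit-normalize each row
    K = M.shape[0]
    cosines = M @ M.T                    # pairwise cosine similarities
    target = -1.0 / (K - 1)              # ideal simplex-ETF cosine
    off_diag = cosines[~torch.eye(K, dtype=torch.bool, device=M.device)]
    return (off_diag - target).abs().mean().item()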

Pre-Centring Means

Where centring is required for means, you can include the global mean mG as a bias argument (as above), or pre-centre them (as below).

means_centred = means - mG
results = {
    "nc1_pinv": covariance_ratio(covar_within, means_centred),
    "nc1_svd": covariance_ratio(covar_within, means_centred, metric="svd"),
    "nc1_quot": covariance_ratio(covar_within, means_centred, metric="quotient"),
    "nc1_cdnv": variability_cdnv(var_norms, means),
    # ...
    "nc5_ood_dev": orthogonality_deviation(means, mG_ood),
}

Note that the uncentred means are still needed for some measurements (such as CDNV) and therefore cannot be discarded, so storing pre-centred copies as well may not be economical memory-wise if K and/or D are large.

Tiling & Reductions

For many of the NC measurement functions, we implement kernel tiling for cases where large embeddings or many classes strain your hardware memory. You may want to tune the square tile size to maximize accelerator throughput.

results = {
    # ...
    "nc1_cdnv": variability_cdnv(var_norms, means, tile_size=64),
    "nc2g_dist": kernel_stats(means, mG, tile_size=64)[1], # var
    "nc2g_log": kernel_stats(means, mG, kernel=log_kernel, tile_size=64)[1], # var
    # ...
}

After kernel_grid produces a symmetric measurement matrix, kernel_stats computes the mean ([0]) and variance ([1]) using triangle row folding.
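
As a rough illustration of the idea (a sketch under assumed interfaces, not the package's implementation), streaming one strictly upper-triangular row at a time keeps memory roughly linear in K while still yielding the mean and variance over all pairwise entries:

def tri_row_stats(xs, kernel):
    # xs: (K, D) tensor; kernel(a, B) returns a 1-D tensor of values between a and each row of B
    count, total, total_sq = 0, 0.0, 0.0
    for i in range(len(xs) - 1):
        row = kernel(xs[i], xs[i + 1:])  # strictly upper-triangular entries of row i
        count += row.numel()
        total += row.sum().item()
        total_sq += row.square().sum().item()
    mean = total / count
    var = total_sq / count - mean ** 2  # population variance over the pairwise entries
    return mean, var

# e.g. distance statistics between centred class means (roughly what the nc2g entries report)
dist_mean, dist_var = tri_row_stats(means - mG, lambda a, B: (B - a).norm(dim=1))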

Development

This project is under active development. Feel free to open issues for bugs, features, optimizations, or papers you would like (us) to implement.

References
