[WIP] modular-pytorch-lightning, WARNING: The repository is currently under development, and is unstable.
What is modular-pytorch-Lightning-Collections⚡ (LightCollections⚡️) for?
- Ever wanted to train `tresnetm50` models and apply TTA (test-time augmentation) or SWA (stochastic weight averaging) to enhance performance? Apply sharpness-aware minimization to semantic segmentation models and measure the difference in calibration? LightCollections is a framework that utilizes and connects existing libraries so such experiments can be run effortlessly.
- Although many popular repositories provide great implementations of algorithms, they are often fragmented and tedious to use together. LightCollections wraps many existing repositories into components of `pytorch-lightning`. We aim to provide training procedures for various subtasks in Computer Vision as a collection of `LightningModule`s, along with utilities for easily using model architectures, metrics, datasets, and training algorithms.
- The components can be used through our system or simply imported into your `pytorch` or `pytorch-lightning` project.
- Currently, the following frameworks are integrated into `LightCollections`:
  - `torchvision.models` for models, `torchvision.transforms` for transforms, and optimizers and learning rate schedules from `pytorch`.
  - Network architectures and weights from `timm`.
  - Object detection frameworks and techniques from `mmdetection`.
  - Keypoint detection (pose estimation) frameworks and techniques from `mmpose`.
  - `imagenet21k` pretrained weights and a feature to load model weights from a url / `.pth` file.
  - `TTAch` for test-time augmentation.
  - Metrics implemented in `torchmetrics`.
- WIP & future TODO:
  - Data augmentation from `albumentations`
  - Semantic segmentation models and weights from `mmsegmentation`

A number of algorithms and research papers are also adopted into our framework. Please refer to the examples below for more information.
%cd /content
!git clone https://github.com/krenerd/awesome-modular-pytorch-lightning
%cd awesome-modular-pytorch-lightning
!pip install -r requirements.txt -q
# (optional) use `wandb` to log progress.
!wandb login
After installing the required packages, you can run the following experiments on Colab.
- CIFAR10 image classification with ResNet18.
!python train.py --name DUMMY-CIFAR10-ResNet18 --configs \
configs/vision/classification/resnet-cifar10.yaml \
configs/vision/models/resnet/resnet18-custom.yaml \
configs/data/cifar10-kuangliu.yaml \
configs/device/gpu.yaml \
configs/utils/wandb.yaml \
configs/utils/train.yaml
- Transfer learning experiments on the Stanford Dogs dataset using TResNet-M
# TResNet models require this extra dependency; the install takes quite long
!pip install git+https://github.com/mapillary/[email protected] -q
!python train.py --name TResNetM-StanfordDogs --config \
configs/data/transfer-learning/training/colab-modification.yaml \
configs/data/transfer-learning/training/random-init.yaml \
configs/data/transfer-learning/cifar10-224.yaml \
configs/data/augmentation/randomresizecrop.yaml \
configs/vision/models/tresnetm.yaml \
configs/device/gpu.yaml \
configs/utils/wandb.yaml \
configs/utils/train.yaml
- Object detection based on `FasterRCNN-FPN` and `mmdetection` on the `voc0712` dataset.
# refer to: https://mmcv.readthedocs.io/en/latest/get_started/installation.html
# install dependencies: (use cu113+torch1.12 because colab has CUDA 11.3)
# install mmcv-full so we can use CUDA operators
!pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.12.0/index.html
# Install mmdetection
!rm -rf mmdetection
!git clone https://github.com/open-mmlab/mmdetection.git
%cd mmdetection
!pip install -e .
# clone MPL
%cd /content
!git clone https://github.com/krenerd/awesome-modular-pytorch-lightning
%cd awesome-modular-pytorch-lightning
!pip install -r requirements.txt -q
# setup voc07+12 dataset
!python tools/download_dataset.py --dataset-name voc0712 --save-dir data --delete --unzip
# run experiment
!python train.py --name voc0712-FasterRCNN-FPN-ResNet50 --config \
configs/vision/object-detection/mmdet/faster-rcnn-r50-fpn-voc0712.yaml \
configs/vision/object-detection/mmdet/mmdet-base.yaml \
configs/data/voc0712-mmdet-no-tta.yaml \
configs/data/voc0712-mmdet.yaml \
configs/device/gpu.yaml \
configs/utils/wandb.yaml \
configs/utils/train.yaml
- 2D Pose estimation based on `HRNet` and `mmpose` on the `MPII` dataset.
# train HRNet pose estimation on MPII
# refer to: https://mmcv.readthedocs.io/en/latest/get_started/installation.html
# install dependencies: (use cu113+torch1.12 because colab has CUDA 11.3)
# install mmcv-full so we can use CUDA operators
!pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.12.0/index.html
# Install mmpose
!rm -rf mmpose
!git clone https://github.com/open-mmlab/mmpose.git
%cd mmpose
!pip install -e .
# clone MPL
%cd ..
# setup MPII dataset
!python tools/download_dataset.py --dataset-name mpii --save-dir data/mpii --delete --unzip
# run experiment
!python train.py --name MPII-HRNet32 --config \
configs/vision/pose/mmpose/hrnet_w32_256x256.yaml \
configs/vision/pose/mmpose/mmpose-base.yaml \
configs/data/pose-2d/mpii-hrnet.yaml \
configs/device/gpu.yaml \
configs/utils/wandb.yaml \
configs/utils/train.yaml
- Supervised VideoPose3D 3D-pose estimation on the `Human3.6M` dataset
!python tools/download_dataset.py --dataset-name human36m_annotation --save-dir human36m --delete --unzip
!python train.py --name Temporal-baseline-bs1024-lr0.001 --config \
configs/vision/pose-lifting/temporal.yaml \
configs/data/human36/temproal-videopose3d.yaml \
configs/data/human36/normalization.yaml \
configs/device/gpu.yaml \
configs/utils/wandb.yaml \
configs/utils/train.yaml
LightCollections can also be used as a library for extending your pytorch-lightning code. `train.py` simply passes the config file to the `Experiment` class defined in `main.py` to build components such as datasets, dataloaders, models, and callbacks, which in turn uses components defined in `catalog`.
...
# create the Experiment helper and initialize the training environment
experiment = Experiment(cfg)
experiment.initialize_environment(cfg=cfg)
# build datasets together with their transforms
datasets = experiment.setup_dataset(
    dataset_cfg=cfg["dataset"],
    transform_cfg=cfg["transform"],
)
# wrap the datasets into dataloaders for each split
dataloaders = experiment.setup_dataloader(
    datasets=datasets,
    dataloader_cfg=cfg["dataloader"],
)
train_dataloader, val_dataloader = dataloaders["trn"], dataloaders["val"]
# build the LightningModule and the logger / callbacks
model = experiment.setup_model(model_cfg=cfg["model"], training_cfg=cfg["training"])
logger_and_callbacks = experiment.setup_callbacks(cfg=cfg)
...
You don't necessarily need to create every component using the `Experiment` class. For example, if you wish to use a custom dataset instead, you can skip `experiment.setup_dataset` and feed your custom dataset to `experiment.setup_dataloader` (see the sketch below). The `Experiment` class simply manages constant global variables such as the label map, normalization mean and standard deviation, and a common log directory, to conveniently create the components.
In an example implemented in `tools/hyperparameter_sweep.py`, I was able to implement hyperparameter sweeping using the `Experiment` class.
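For instance, a minimal sketch of that pattern might look like the following. This is a hypothetical example (the dataset class and `train_samples` / `val_samples` are placeholders); it assumes `setup_dataloader` accepts a plain dict of datasets keyed by `"trn"` / `"val"`, as in the excerpt above.

```python
from torch.utils.data import Dataset


class MyDataset(Dataset):
    """Placeholder custom dataset; replace with your own data-loading logic."""

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


# skip `setup_dataset` and hand the custom datasets to `setup_dataloader`
experiment = Experiment(cfg)
experiment.initialize_environment(cfg=cfg)
dataloaders = experiment.setup_dataloader(
    datasets={"trn": MyDataset(train_samples), "val": MyDataset(val_samples)},
    dataloader_cfg=cfg["dataloader"],
)
```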
Training involves many configs. `LightCollections` implements a cascading config system where we use multiple layers of config files to define different parts of the experiment. For example, in the CIFAR10 example above, we combine 6 config files:
configs/vision/classification/resnet-cifar10.yaml \
configs/vision/models/resnet/resnet18-custom.yaml \
configs/data/cifar10-kuangliu.yaml \
configs/device/gpu.yaml \
configs/utils/wandb.yaml \
configs/utils/train.yaml
Each layer defines something different, such as the dataset, network architecture, or training procedure. If we wanted to use a `ResNet50` model instead, we could replace
configs/vision/models/resnet/resnet18-custom.yaml
-> configs/vision/models/resnet/resnet50-custom.yaml
These cascading config files are baked at the start of `train.py` (a sketch of the merge is shown below). Configs listed in front have higher priority. The baked config files are logged under `configs/logs/{experiment_name}.(yaml/pkl/json)` for logging and reproduction purposes.
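To make the precedence concrete, here is a small, hypothetical sketch of how such a cascade could be baked. The actual merge logic lives in `train.py` and its utilities and may differ in detail; the point is only that files listed earlier override values from files listed later.

```python
import yaml


def merge(base: dict, override: dict) -> dict:
    """Recursively merge `override` into `base`; values from `override` win."""
    out = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = merge(out[key], value)
        else:
            out[key] = value
    return out


def bake_configs(paths):
    """Files listed later act as defaults; files listed earlier override them."""
    cfg = {}
    for path in reversed(paths):
        with open(path) as f:
            cfg = merge(cfg, yaml.safe_load(f) or {})
    return cfg


cfg = bake_configs([
    "configs/vision/models/resnet/resnet18-custom.yaml",
    "configs/data/cifar10-kuangliu.yaml",
    "configs/utils/train.yaml",
])
```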
- If a feature is not implemented yet, you may take an instance of `main.py: Experiment` and override any part of it.
- Training procedure (`LightningModule`): available training procedures are listed in `lightning/trainers.py`
- Model architectures:
  - Backbone models implemented in `torchvision.models` can be used.
  - Backbone models implemented in `timm` can be used.
  - Although we highly recommend using `timm`, as it provides a large variety of computer vision models that are thoroughly evaluated, custom implementations of some architectures are listed in `catalog/models/__init__.py`.
- Dataset
  - Dataset: currently only `torchvision` datasets are supported by `Experiment`; however, `torchvision.datasets.ImageFolder` can be used to load a custom dataset. In addition, you may use a custom dataset and combine it with the transforms, models, and training features of the repo.
  - Transformations (data augmentation): transforms must be listed in `data/transforms/vision/__init__.py`
- Other features
- Optimizers
- Metrics / loss
The results of experiments, such as model checkpoints, logs, and the config file used to run the experiment, are logged under `awesome-modular-pytorch-lightning/results/{exp_name}`. In particular, the `results/{exp_name}/configs/cfg.yaml` file, which contains the baked config, can be useful when reproducing experiments or rechecking hyperparameters.
For computer vision models, we recommend borrowing architectures from `timm`, as it provides robust implementations and pretrained weights for a wide variety of architectures. We remove the classification head from the original models.
`timm` provides a wide variety of architectures for computer vision. `timm.list_models()` returns a complete list of available models in timm. Models can be created with `timm.create_model()`.
An example of creating a `resnet50` model using `timm`:
import timm

model = timm.create_model("resnet50", pretrained=True)
To use `timm` models,
- set `model.backbone.name` to `TimmNetwork`.
- set `model.backbone.args.name` to the model name.
- set additional arguments in `model.backbone.cfg`.
- Refer to: `configs/vision/models/resnet/resnet50-timm.yaml`
model:
backbone:
name: "TimmNetwork"
args:
name: "resnet50"
args:
pretrained: True
out_features: 2048
`torchvision.models` also provides a number of architectures for computer vision. The list of models can be found here.
An example of creating a `resnet50` model using `torchvision`:
import torchvision

model = torchvision.models.resnet50()
To use `torchvision` models,
- set `model.backbone.name` to `TorchvisionNetwork`.
- set `model.backbone.args.name` to the model name.
- set additional arguments in `model.backbone.args`.
- set `model.backbone.drop_after` to only use the feature extractor.
- Refer to: `configs/vision/models/resnet/resnet50-torchvision.yaml`
model:
backbone:
name: "TorchvisionNetwork"
args:
name: "resnet50"
args:
pretrained: False
drop_after: "avgpool"
out_features: 2048
- Paper: https://arxiv.org/abs/1409.4842
- Note: Common data augmentation strategy for ImageNet using RandomResizedCrop.
- Refer to: `configs/data/augmentation/randomresizecrop.yaml`
transform: [
[
"trn",
[
{
"name": "RandomResizedCropAndInterpolation",
"args":
{
"size": 224,
"scale": [0.08, 1.0],
"ratio": [0.75, 1.3333],
"interpolation": "random",
},
},
{
"name": "TorchvisionTransform",
"args": { "name": "RandomHorizontalFlip" },
},
# more data augmentation (rand augment, auto augment, ...)
],
],
[
# standard approach to use images cropped to the central 87.5% for validation
"val,test",
[
{
"name": "Resize",
"args": { "size": [256, 256], "interpolation": "bilinear" },
},
{
"name": "TorchvisionTransform",
"args": { "name": "CenterCrop", "args": { "size": [224, 224] } },
},
# more data augmentation (rand augment, auto augment, ...)
],
],
[
"trn,val,test",
[
{ "name": "ImageToTensor", "args": {} },
{
"name": "Normalize",
"args":
{
"mean": "{const.normalization_mean}",
"std": "{const.normalization_std}",
},
},
],
],
]
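For reference, the training / validation pipeline above corresponds roughly to the following plain `torchvision` code. This is only an approximation: the repo builds transforms through its own catalog, resolves `{const.normalization_mean}` / `{const.normalization_std}` at runtime, and supports random interpolation, while ImageNet statistics are hard-coded here as an example.

```python
from torchvision import transforms

# ImageNet statistics used as an example; the repo substitutes dataset constants.
MEAN, STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

train_tf = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(0.75, 1.3333)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD),
])
val_tf = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD),
])
```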
- Paper: https://arxiv.org/abs/1909.13719
- Note: Commonly used data augmentation strategy for image classification.
- Refer to: `configs/data/augmentation/randaugment.yaml`
transform:
...
{
"name": "TorchvisionTransform",
"args":
{
"name": "RandAugment",
"args": { "num_ops": 2, "magnitude": 9 }
},
},
...
- Paper: https://arxiv.org/abs/2103.10158
- Note: Commonly used data augmentation strategy for image classification.
- Refer to: `configs/data/augmentation/trivialaugment.yaml`
transform:
...
{
"name": "TorchvisionTransform",
"args":
{
"name": "TrivialAugmentWide",
"args": { "num_magnitude_bins": 31 },
},
},
...
- Paper: https://arxiv.org/abs/1710.09412
- Note: Commonly used data augmentation strategy for image classification. As the mixed labels become continuous (soft) values, the loss function should be modified accordingly; a minimal sketch follows the config below.
- Refer to: `configs/data/augmentation/mixup/mixup.yaml`
training:
mixup_cutmix:
mixup_alpha: 1.0
cutmix_alpha: 0.0
cutmix_minmax: null
prob: 1.0
switch_prob: 0.5
mode: "batch"
correct_lam: True
label_smoothing: 0.1
num_classes: 1000
model:
modules:
loss_fn:
name: "SoftTargetCrossEntropy"
- Paper: https://arxiv.org/abs/1905.04899
- Note: Commonly used data augmentation strategy for image classification. As the labels are continuous values, the loss function should be modified accordingly.
- Refer to: `configs/data/augmentation/mixup/cutmix.yaml`
training:
mixup_cutmix:
mixup_alpha: 0.0
cutmix_alpha: 1.0
cutmix_minmax: null
prob: 1.0
switch_prob: 0.5
mode: "batch"
correct_lam: True
label_smoothing: 0.1
num_classes: "{const.num_classes}"
model:
modules:
loss_fn:
name: "SoftTargetCrossEntropy"
- Paper: https://arxiv.org/abs/1708.04552
- Note: Commonly used data augmentation strategy for image classification.
- Refer to: `configs/data/augmentation/cutout.yaml` and `configs/data/augmentation/cutout_multiple.yaml`
transform:
...
[
"trn",
[
...
{
# CutOut!!
"name": "CutOut",
"args": { "mask_size": 0.5, "num_masks": 1 },
},
],
],
...
- Paper: https://arxiv.org/abs/2110.00476
- Note: By default, models trained in `timm` randomly switch between `Mixup` and `CutMix` data augmentation. This is found to be effective in their paper.
- Refer to: `configs/data/augmentation/mixup/mixup_cutmix.yaml`
training:
mixup_cutmix:
mixup_alpha: .8
cutmix_alpha: 1.0
cutmix_minmax: null
prob: 1.0
switch_prob: 0.5
mode: "batch"
correct_lam: True
label_smoothing: 0.1
num_classes: 1000
model:
modules:
loss_fn:
name: "SoftTargetCrossEntropy"
- Note: Commonly used regularization strategy.
- Refer to: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
model:
modules:
loss_fn:
name: "CrossEntropyLoss"
args:
label_smoothing: 0.1
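The config above corresponds to the standard PyTorch loss, which can also be used directly:

```python
import torch.nn as nn

# cross-entropy with label smoothing, equivalent to the config above
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
```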
- Note: Commonly used regularization strategy.
- Refer to: torch.optim.SGD and `configs/vision/classification/resnet-cifar10.yaml`
training:
optimizer_cfg:
weight_decay: 0.0005
- Paper: https://jmlr.org/papers/v15/srivastava14a.html
- Note: Commonly used regularization strategy.
- Refer to: `configs/vision/classification/resnet-cifar10.yaml`
model:
modules:
classifier:
args:
dropout: 0.2
- Paper: https://arxiv.org/abs/2106.14448
- Note: Regularization strategy that minimizes the KL-divergence between the output distributions of two sub-models sampled by dropout.
- Refer to: `configs/algorithms/rdrop.yaml`
training:
rdrop:
alpha: 0.6
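For intuition, a generic sketch of the R-Drop objective described above is shown below (not necessarily the repo's exact implementation): two stochastic forward passes through the same model, cross-entropy on both, plus a symmetric KL term weighted by `alpha`.

```python
import torch.nn.functional as F


def rdrop_loss(model, x, y, alpha=0.6):
    # dropout makes the two forward passes produce different predictions
    logits1, logits2 = model(x), model(x)
    ce = 0.5 * (F.cross_entropy(logits1, y) + F.cross_entropy(logits2, y))
    logp1 = F.log_softmax(logits1, dim=-1)
    logp2 = F.log_softmax(logits2, dim=-1)
    kl = 0.5 * (
        F.kl_div(logp1, logp2, log_target=True, reduction="batchmean")
        + F.kl_div(logp2, logp1, log_target=True, reduction="batchmean")
    )
    return ce + alpha * kl
```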
- Paper: https://arxiv.org/abs/2010.01412
- Note: Sharpness-aware minimization aims at finding flat minima. It is demonstrated to improve generalization and robustness to label noise. However, two backpropagation passes are needed at every optimization step, which doubles the training time; a minimal sketch of the update is shown after the config below.
- Refer to: `configs/algorithms/sharpness-aware-minimization.yaml`
training:
sharpness-aware:
rho: 0.05
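Below is a minimal, generic sketch of one SAM update, illustrating why every optimization step needs two forward/backward passes. It is not the repo's implementation; `model`, `optimizer`, `loss_fn`, `x`, and `y` are placeholders.

```python
import torch


def sam_step(model, optimizer, loss_fn, x, y, rho=0.05):
    # 1) gradient at the current weights
    loss_fn(model(x), y).backward()
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    grad_norm = torch.norm(torch.stack([g.norm(2) for g in grads]), 2) + 1e-12

    # 2) perturb weights towards the worst-case nearby point: e = rho * g / ||g||
    eps = []
    with torch.no_grad():
        for p in model.parameters():
            if p.grad is None:
                eps.append(None)
                continue
            e = rho * p.grad / grad_norm
            p.add_(e)
            eps.append(e)
    optimizer.zero_grad()

    # 3) gradient at the perturbed weights (the second backward pass)
    loss_fn(model(x), y).backward()
    with torch.no_grad():
        for p, e in zip(model.parameters(), eps):
            if e is not None:
                p.sub_(e)  # restore the original weights
    optimizer.step()
    optimizer.zero_grad()
```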
- Paper: https://arxiv.org/abs/2204.12511
- Note: The authors derive the Taylor expansion of cross-entropy and demonstrate that modifying the coefficient of the first-order term can improve performance (see the sketch below).
- Refer to: `configs/algorithms/loss/poly1.yaml`
model:
modules:
loss_fn:
name: "PolyLoss"
file: "loss"
args:
eps: 2.0
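For reference, the Poly-1 form of the loss described in the paper can be sketched as cross-entropy plus `eps * (1 - p_t)`; the repo's version lives in its loss catalog and may differ slightly.

```python
import torch
import torch.nn.functional as F


def poly1_cross_entropy(logits, targets, eps=2.0):
    ce = F.cross_entropy(logits, targets, reduction="none")
    # probability assigned to the true class
    pt = torch.softmax(logits, dim=-1).gather(1, targets.unsqueeze(1)).squeeze(1)
    return (ce + eps * (1.0 - pt)).mean()
```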
- Paper: https://arxiv.org/abs/2110.00476
- Note: The authors show that BCE loss can be used for classification tasks and achieves similar or better performance.
- Refer to: `configs/algorithms/loss/classification_bce.yaml`
model:
modules:
loss_fn:
name: "OneToAllBinaryCrossEntropy"
file: "loss"
(TODO)
- Paper: https://arxiv.org/abs/1503.02531
- Note: Distill knowledge from a large network to a small network by minimizing the KL divergence between the teacher and student predictions (a generic sketch is given below).
- Refer to: TODO
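Since this is not implemented in the repo yet, the following is only a generic sketch of the distillation loss described in the note: KL divergence between temperature-softened teacher and student predictions, mixed with the usual cross-entropy on the hard labels.

```python
import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, targets, T=4.0, alpha=0.5):
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=-1),
        F.softmax(teacher_logits / T, dim=-1),
        reduction="batchmean",
    ) * (T * T)  # rescale so gradients stay comparable across temperatures
    hard = F.cross_entropy(student_logits, targets)
    return alpha * soft + (1.0 - alpha) * hard
```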
- Note: Used to prevent gradient explosion and stabilize training by clipping large gradients. Recently, it has been demonstrated to have a number of additional benefits.
- Refer to: `configs/algorithms/optimizer/gradient_clipping.yaml` and `configs/algorithms/optimizer/gradient_clipping_maulat_optimization.yaml`
trainer:
gradient_clip_val: 1.0
- Paper: https://arxiv.org/abs/1803.05407
- Note: Average multiple checkpoints during training for better performance. An awesome overview of the algorithm is provided by pytorch. Luckily, `pytorch-lightning` provides an easy-to-use callback that implements SWA. To train a SWA model from an existing checkpoint, you may set `swa_epoch_start: 0.0`.
- Refer to: `configs/algorithms/swa.yaml`
callbacks:
StochasticWeightAveraging:
name: "StochasticWeightAveraging"
file: "lightning"
args:
swa_lrs: 0.02 # typically x0.2 ~ x0.5 of initial lr
swa_epoch_start: 0.75
annealing_epochs: 5 # smooth the connection between lr schedule and SWA.
annealing_strategy: "cos"
avg_fn: null
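Outside of the config system, the same `pytorch-lightning` callback can be attached to a plain `Trainer`; the values below simply mirror the config above.

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import StochasticWeightAveraging

trainer = pl.Trainer(
    max_epochs=100,
    callbacks=[StochasticWeightAveraging(swa_lrs=0.02, swa_epoch_start=0.75)],
)
```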
- `timm` or `torchvision` models have an argument called `pretrained` for loading a pretrained feature extractor (typically ImageNet-trained).
- To load from custom checkpoints, you can specify a url or path to the state dict in `model.backbone.weights`. For example, `configs/vision/models/imagenet21k/imagenet21k_resnet50.yaml` loads ImageNet-21K checkpoints from a url proposed in the paper: ImageNet-21K Pretraining for the Masses. A path to the state dict can be provided in `model.backbone.weights.state_dict_path` instead of the url. This is implemented in `lightning/base.py:L23`; a rough sketch is shown after the config below.
model:
backbone:
name: "TimmNetwork"
args:
name: "resnet50"
args:
pretrained: False
out_features: 2048
weights:
is_ckpt: True
url: "https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/resnet50_miil_21k.pth"
# alternatively, `state_dict_path: {PATH TO STATE DICT}`
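For intuition, a rough sketch of such weight loading is shown below; the repo's actual logic lives in `lightning/base.py` and may differ in detail (e.g. key remapping).

```python
import torch


def load_backbone_weights(model, url=None, state_dict_path=None, is_ckpt=False):
    """Load a state dict from a url or a local path into `model`."""
    if url is not None:
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
    else:
        state_dict = torch.load(state_dict_path, map_location="cpu")
    if is_ckpt:
        # checkpoints keep the weights under a "state_dict" key
        state_dict = state_dict["state_dict"]
    # tolerate missing / extra keys such as the classification head
    model.load_state_dict(state_dict, strict=False)
    return model
```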
To resume from a checkpoint, provide the path to the state dict in `model.state_dict_path`. Checkpoints generated using the `ModelCheckpoint` callback contain the state dict under the `state_dict` key, while saving with `torch.save(model.state_dict())` directly saves the state dict. The `is_ckpt` argument should be true if the state dict was generated through the `ModelCheckpoint` callback.
model:
is_ckpt: True # True / False according to the type of the state dict.
state_dict_path: {PATH TO STATE DICT}
- Paper: https://arxiv.org/abs/1708.07120
- Note: Linearly increases learning rate from 0 to maximum for the first half of training then linearly decreases to 0. Commonly used learning rate schedule.
- Refer to: `configs/algorithms/lr-schedule/1cycle.yaml` and torch.optim.lr_scheduler.OneCycleLR
training:
lr_scheduler:
name: "1cycle"
args:
# Original version updates per-batch but we modify to update per-epoch.
pct_start: 0.3
max_lr: "{training.lr}"
anneal_strategy: "linear"
total_steps: "{training.epochs}"
cfg:
interval: "epoch"
- Paper: https://arxiv.org/abs/1608.03983
- Note: Decays the learning rate from the initial value to 0 following a cosine function. Commonly used learning rate schedule.
- Refer to: `configs/vision/classification/resnet-cifar10.yaml` and torch.optim.lr_scheduler.CosineAnnealingLR
training:
lr_scheduler:
name: "cosine"
args:
T_max: "{training.epochs}"
cfg:
interval: "epoch"
- Paper: https://arxiv.org/abs/1706.02677
- Note: Commonly used strategy to stabilize training at early stages.
- Refer to: `configs/algorithms/lr-schedule/warmup.yaml`
training:
lr_warmup:
multiplier: 1
total_epoch: 5
- Refer to: `configs/algorithms/tta/hvflip.yaml`
model:
tta:
name: "ClassificationTTAWrapper"
args:
output_label_key: "logits"
merge_mode: "mean"
transforms:
- name: "HorizontalFlip"
- name: "VerticalFlip"
- and `configs/algorithms/tta/rotation.yaml`
model:
tta:
name: "ClassificationTTAWrapper"
args:
merge_mode: "mean"
transforms:
- name: "HorizontalFlip"
- name: "Rotation"
args:
angles:
- 0
- 30
- 60
- 90
- 120
- 150
- 180
- 210
- 240
- 270
- 300
- 330
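The configs above wrap the `TTAch` library; for comparison, the equivalent direct usage of `ttach` looks roughly like this (the exact integration in the repo may differ).

```python
import ttach as tta

# `model` is a trained classification model producing logits
tta_transforms = tta.Compose([
    tta.HorizontalFlip(),
    tta.VerticalFlip(),
])
tta_model = tta.ClassificationTTAWrapper(model, tta_transforms, merge_mode="mean")
# tta_model(images) now averages predictions over all flip combinations
```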
Currently, only `torchvision` datasets are supported by `Experiment`; however, you can use `torchvision.datasets.ImageFolder` to load a custom dataset.
Transformations (data augmentation): transforms must be listed in `data/transforms/vision/__init__.py`.
- Contact: [email protected]