Commit ed1270c (1 parent: d18ca93) — 7 changed files with 768 additions and 0 deletions.
# UMT Project

[Unmasked Teacher: Towards Training-Efficient Video Foundation Models](https://arxiv.org/abs/2303.16058)

<!-- [ALGORITHM] -->

## Abstract

<!-- [ABSTRACT] -->

Video Foundation Models (VFMs) have received limited exploration due to high computational costs and data scarcity. Previous VFMs rely on Image Foundation Models (IFMs), which face challenges in transferring to the video domain. Although VideoMAE has trained a robust ViT from limited data, its low-level reconstruction poses convergence difficulties and conflicts with high-level cross-modal alignment. This paper proposes a training-efficient method for temporal-sensitive VFMs that integrates the benefits of existing methods. To increase data efficiency, we mask out most of the low-semantics video tokens, but selectively align the unmasked tokens with the IFM, which serves as the UnMasked Teacher (UMT). By providing semantic guidance, our method enables faster convergence and multimodal friendliness. With a progressive pre-training framework, our model can handle various tasks including scene-related, temporal-related, and complex video-language understanding. Using only public sources for pre-training in 6 days on 32 A100 GPUs, our scratch-built ViT-L/16 achieves state-of-the-art performance on various video tasks.

<!-- [IMAGE] -->

<div align=center>
<img src="https://github-production-user-asset-6210df.s3.amazonaws.com/58767402/262291190-bdaa6899-e1d6-460f-b329-23d8b38511f3.png" width="800"/>
</div>

## Usage

### Setup Environment

Please refer to [Installation](https://mmaction2.readthedocs.io/en/latest/get_started/installation.html) to install MMAction2.

Assume that you are located at `$MMACTION2/projects/umt`.

Add the current folder to `PYTHONPATH` so that Python can find your code. Run the following command in the current directory to add it:

> Please run it every time you open a new shell.

```shell
export PYTHONPATH=`pwd`:$PYTHONPATH
```
### Data Preparation

Prepare the Kinetics dataset according to the [instructions](https://github.com/open-mmlab/mmaction2/tree/main/tools/data/kinetics#readme).

Create a symbolic link from `$MMACTION2/data` to `./data` in the current directory so that Python can locate your data. Run the following command in the current directory to create the symbolic link:

```shell
ln -s ../../data ./data
```

### Testing Commands

**To test with a single GPU:**

```bash
mim test mmaction configs/umt-base-p16-res224_kinetics710-pre-ft_u8_k400-rgb.py --checkpoint $CHECKPOINT
```

**To test with multiple GPUs:**

```bash
mim test mmaction configs/umt-base-p16-res224_kinetics710-pre-ft_u8_k400-rgb.py --checkpoint $CHECKPOINT --launcher pytorch --gpus 8
```

**To test with multiple GPUs via Slurm:**

```bash
mim test mmaction configs/umt-base-p16-res224_kinetics710-pre-ft_u8_k400-rgb.py --checkpoint $CHECKPOINT --launcher slurm \
    --gpus 8 --gpus-per-node 8 --partition $PARTITION
```
## Results

### Kinetics400

| frame sampling strategy | resolution | backbone | pretrain | top1 acc | testing protocol | config | ckpt |
| :---------------------: | :--------: | :------: | :---------: | :------: | :--------------: | :----: | :--: |
| uniform 8 | 224x224 | UMT-B | Kinetics710 | 87.33 | 4 clips x 3 crop | [config](./configs/umt-base-p16-res224_kinetics710-pre-ft_u8_k400-rgb.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/projects/umt/umt-base-p16-res224_kinetics710-pre-ft_u8_k400-rgb/umt-base-p16-res224_kinetics710-pre-ft_u8_k400-rgb.pth) |
| uniform 8 | 224x224 | UMT-L | Kinetics710 | 90.21 | 4 clips x 3 crop | [config](./configs/umt-large-p16-res224_kinetics710-pre-ft_u8_k400-rgb.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/projects/umt/umt-large-p16-res224_kinetics710-pre-ft_u8_k400-rgb/umt-large-p16-res224_kinetics710-pre-ft_u8_k400-rgb.pth) |

### Kinetics700

| frame sampling strategy | resolution | backbone | pretrain | top1 acc | testing protocol | config | ckpt |
| :---------------------: | :--------: | :------: | :---------: | :------: | :--------------: | :----: | :--: |
| uniform 8 | 224x224 | UMT-B | Kinetics710 | 77.95 | 4 clips x 3 crop | [config](./configs/umt-base-p16-res224_kinetics710-pre-ft_u8_k700-rgb.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/projects/umt/umt-base-p16-res224_kinetics710-pre-ft_u8_k700-rgb/umt-base-p16-res224_kinetics710-pre-ft_u8_k700-rgb.pth) |
| uniform 8 | 224x224 | UMT-L | Kinetics710 | 82.79 | 4 clips x 3 crop | [config](./configs/umt-large-p16-res224_kinetics710-pre-ft_u8_k700-rgb.py) | [ckpt](https://download.openmmlab.com/mmaction/v1.0/projects/umt/umt-large-p16-res224_kinetics710-pre-ft_u8_k700-rgb/umt-large-p16-res224_kinetics710-pre-ft_u8_k700-rgb.pth) |
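The "4 clips x 3 crop" testing protocol means each video is scored from 4 uniformly sampled clips, each seen under 3 spatial crops, and the 12 per-view predictions are fused. A minimal sketch of that fusion, assuming `average_clips='prob'` softmaxes each view's logits before averaging (the behavior suggested by the configs in this commit):

```python
import math

def softmax(logits):
    # Numerically stable softmax over one view's class logits.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def fuse_views(view_logits):
    # average_clips='prob': convert each view to probabilities, then average.
    probs = [softmax(v) for v in view_logits]
    n_views, n_classes = len(probs), len(probs[0])
    return [sum(p[c] for p in probs) / n_views for c in range(n_classes)]

# 4 uniformly sampled clips x 3 spatial crops = 12 views per video
views = [[2.0, 1.0, 0.1]] * 12
fused = fuse_views(views)                             # a valid distribution
top1 = max(range(len(fused)), key=fused.__getitem__)  # predicted class index
```

Averaging probabilities rather than raw logits keeps each view's contribution bounded, so one overconfident crop cannot dominate the fused score.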
## Citation

```bibtex
@article{li2023unmasked,
  title={Unmasked teacher: Towards training-efficient video foundation models},
  author={Li, Kunchang and Wang, Yali and Li, Yizhuo and Wang, Yi and He, Yinan and Wang, Limin and Qiao, Yu},
  journal={arXiv preprint arXiv:2303.16058},
  year={2023}
}
```
## projects/umt/configs/umt-base-p16-res224_kinetics710-pre-ft_u8_k400-rgb.py (82 additions)

```python
custom_imports = dict(imports='models')

# model settings
model = dict(
    type='Recognizer3D',
    backbone=dict(
        type='UMTViT',
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        all_frames=8,
        qkv_bias=True),
    cls_head=dict(
        type='TimeSformerHead',
        num_classes=400,
        in_channels=768,
        average_clips='prob'),
    data_preprocessor=dict(
        type='ActionDataPreprocessor',
        mean=[114.75, 114.75, 114.75],
        std=[57.375, 57.375, 57.375],
        format_shape='NCTHW'))

# dataset settings
dataset_type = 'VideoDataset'
data_root_val = 'data/kinetics400/videos_val'
ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt'

file_client_args = dict(io_backend='disk')

test_pipeline = [
    dict(type='DecordInit', **file_client_args),
    dict(type='UniformSample', clip_len=8, num_clips=4, test_mode=True),
    dict(type='DecordDecode'),
    dict(type='Resize', scale=(-1, 224)),
    dict(type='ThreeCrop', crop_size=224),
    dict(type='FormatShape', input_format='NCTHW'),
    dict(type='PackActionInputs')
]

test_dataloader = dict(
    batch_size=8,
    num_workers=16,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        ann_file=ann_file_test,
        data_prefix=dict(video=data_root_val),
        pipeline=test_pipeline,
        test_mode=True))

test_evaluator = dict(type='AccMetric')
test_cfg = dict(type='TestLoop')

default_scope = 'mmaction'

default_hooks = dict(
    runtime_info=dict(type='RuntimeInfoHook'),
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=20, ignore_last=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(
        type='CheckpointHook', interval=1, save_best='auto', max_keep_ckpts=5),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    sync_buffers=dict(type='SyncBuffersHook'))

env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'))

log_processor = dict(type='LogProcessor', window_size=20, by_epoch=True)

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(type='ActionVisualizer', vis_backends=vis_backends)

log_level = 'INFO'
load_from = None
resume = False
```
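As a sanity check on the test pipeline above, the shape of one video's packed input can be derived from the config values. This is a sketch under the assumption that `UniformSample` yields `num_clips` clips of `clip_len` frames each and `ThreeCrop` triples the spatial views:

```python
def test_input_shape(clip_len=8, num_clips=4, num_crops=3,
                     crop_size=224, channels=3):
    # NCTHW: N = total views (clips x crops), C channels,
    # T frames per clip, H x W spatial crop.
    n_views = num_clips * num_crops
    return (n_views, channels, clip_len, crop_size, crop_size)

shape = test_input_shape()  # values taken from the config above
```

With `batch_size=8`, the dataloader therefore hands the model 8 such tensors per step, i.e. 96 views in total.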
## projects/umt/configs/umt-base-p16-res224_kinetics710-pre-ft_u8_k700-rgb.py (82 additions)

```python
custom_imports = dict(imports='models')

# model settings
model = dict(
    type='Recognizer3D',
    backbone=dict(
        type='UMTViT',
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        all_frames=8,
        qkv_bias=True),
    cls_head=dict(
        type='TimeSformerHead',
        num_classes=700,
        in_channels=768,
        average_clips='prob'),
    data_preprocessor=dict(
        type='ActionDataPreprocessor',
        mean=[114.75, 114.75, 114.75],
        std=[57.375, 57.375, 57.375],
        format_shape='NCTHW'))

# dataset settings
dataset_type = 'VideoDataset'
data_root_val = 'data/kinetics700/videos_val'
ann_file_test = 'data/kinetics700/kinetics700_val_list_videos.txt'

file_client_args = dict(io_backend='disk')

test_pipeline = [
    dict(type='DecordInit', **file_client_args),
    dict(type='UniformSample', clip_len=8, num_clips=4, test_mode=True),
    dict(type='DecordDecode'),
    dict(type='Resize', scale=(-1, 224)),
    dict(type='ThreeCrop', crop_size=224),
    dict(type='FormatShape', input_format='NCTHW'),
    dict(type='PackActionInputs')
]

test_dataloader = dict(
    batch_size=8,
    num_workers=16,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        ann_file=ann_file_test,
        data_prefix=dict(video=data_root_val),
        pipeline=test_pipeline,
        test_mode=True))

test_evaluator = dict(type='AccMetric')
test_cfg = dict(type='TestLoop')

default_scope = 'mmaction'

default_hooks = dict(
    runtime_info=dict(type='RuntimeInfoHook'),
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=20, ignore_last=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(
        type='CheckpointHook', interval=1, save_best='auto', max_keep_ckpts=5),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    sync_buffers=dict(type='SyncBuffersHook'))

env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'))

log_processor = dict(type='LogProcessor', window_size=20, by_epoch=True)

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(type='ActionVisualizer', vis_backends=vis_backends)

log_level = 'INFO'
load_from = None
resume = False
```
## projects/umt/configs/umt-large-p16-res224_kinetics710-pre-ft_u8_k400-rgb.py (82 additions)

```python
custom_imports = dict(imports='models')

# model settings
model = dict(
    type='Recognizer3D',
    backbone=dict(
        type='UMTViT',
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        all_frames=8,
        qkv_bias=True),
    cls_head=dict(
        type='TimeSformerHead',
        num_classes=400,
        in_channels=1024,
        average_clips='prob'),
    data_preprocessor=dict(
        type='ActionDataPreprocessor',
        mean=[114.75, 114.75, 114.75],
        std=[57.375, 57.375, 57.375],
        format_shape='NCTHW'))

# dataset settings
dataset_type = 'VideoDataset'
data_root_val = 'data/kinetics400/videos_val'
ann_file_test = 'data/kinetics400/kinetics400_val_list_videos.txt'

file_client_args = dict(io_backend='disk')

test_pipeline = [
    dict(type='DecordInit', **file_client_args),
    dict(type='UniformSample', clip_len=8, num_clips=4, test_mode=True),
    dict(type='DecordDecode'),
    dict(type='Resize', scale=(-1, 224)),
    dict(type='ThreeCrop', crop_size=224),
    dict(type='FormatShape', input_format='NCTHW'),
    dict(type='PackActionInputs')
]

test_dataloader = dict(
    batch_size=8,
    num_workers=16,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        ann_file=ann_file_test,
        data_prefix=dict(video=data_root_val),
        pipeline=test_pipeline,
        test_mode=True))

test_evaluator = dict(type='AccMetric')
test_cfg = dict(type='TestLoop')

default_scope = 'mmaction'

default_hooks = dict(
    runtime_info=dict(type='RuntimeInfoHook'),
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=20, ignore_last=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(
        type='CheckpointHook', interval=1, save_best='auto', max_keep_ckpts=5),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    sync_buffers=dict(type='SyncBuffersHook'))

env_cfg = dict(
    cudnn_benchmark=False,
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    dist_cfg=dict(backend='nccl'))

log_processor = dict(type='LogProcessor', window_size=20, by_epoch=True)

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(type='ActionVisualizer', vis_backends=vis_backends)

log_level = 'INFO'
load_from = None
resume = False
```
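The large config differs from the base one only in backbone width and depth (`embed_dim=1024`, `depth=24`, `num_heads=16`). A rough back-of-the-envelope sketch of what that costs in parameters, counting only the transformer blocks (about 4·d² for the QKV and output projections plus 2·mlp_ratio·d² for the MLP, per block) and ignoring patch embedding, positional embeddings, LayerNorms, and the head:

```python
def approx_vit_params(embed_dim, depth, mlp_ratio=4):
    # Per-block estimate: 4*d^2 attention weights + 2*mlp_ratio*d^2 MLP
    # weights (biases and norms omitted), summed over all blocks.
    per_block = 4 * embed_dim**2 + 2 * mlp_ratio * embed_dim**2
    return depth * per_block

base = approx_vit_params(768, 12)    # ~85M, in the ViT-B ballpark
large = approx_vit_params(1024, 24)  # ~302M, in the ViT-L ballpark
```

The estimate lands close to the commonly quoted ViT-B/ViT-L sizes, which is why the large model's testing throughput and memory footprint are several times those of the base config at the same input shape.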