Pytorch implementation of the AGST tracker
AGST - Attention-based Gating network for Segmentation Tracking

Publication

Yijin Yang and Xiaodong Gu. Attention-based Gating Network for Robust Segmentation Tracking. IEEE Transactions on Circuits and Systems for Video Technology (TCSVT), 2024.

Abstract

Visual object tracking is a challenging task that aims to accurately estimate the scale and position of a designated target. Recently, segmentation networks have proven effective in visual tracking, producing outstanding results for target scale estimation. However, segmentation-based trackers still lack robustness due to the presence of similar distractors. To mitigate this issue, we propose an Attention-based Gating Network (AGNet) that produces gating weights to diminish the impact of feature maps linked to similar distractors. Subsequently, we incorporate the AGNet into the segmentation-based tracking paradigm to achieve accurate and robust tracking. Specifically, the AGNet utilizes three cascading Multi-Head Cross-Attention (MHCA) modules to generate gating weights that govern the generation of feature maps in the baseline tracker. The proficiency of the MHCA in modeling global semantic information effectively suppresses feature maps associated with similar distractors. Additionally, we introduce a distractor-aware training strategy that leverages distractor masks to train our model. To alleviate the issue of partial occlusion, we introduce a box refinement module to enhance the accuracy of the predicted target box. Comprehensive experiments conducted on 11 challenging tracking benchmarks show that our approach significantly surpasses the baseline tracker across all metrics and achieves excellent results on multiple tracking benchmarks.
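The core idea described above can be illustrated with a minimal PyTorch sketch. This is not the authors' implementation: the layer sizes, the sigmoid gate head, and all names (`GatingBlock`, `AGNetSketch`) are illustrative assumptions; only the structure (cascaded multi-head cross-attention stages producing multiplicative gating weights over the feature maps) follows the description in the abstract.

```python
import torch
import torch.nn as nn

class GatingBlock(nn.Module):
    """One cross-attention stage: queries come from the test-frame features,
    keys/values from the template features. Sizes are illustrative."""
    def __init__(self, dim=64, heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, query, memory):
        out, _ = self.attn(query, memory, memory)  # cross-attention
        return self.norm(query + out)              # residual + norm

class AGNetSketch(nn.Module):
    """Three cascaded MHCA stages whose output is squashed to gating
    weights in [0, 1], used to rescale the feature maps and suppress
    responses linked to similar distractors."""
    def __init__(self, dim=64, heads=4, stages=3):
        super().__init__()
        self.stages = nn.ModuleList(GatingBlock(dim, heads) for _ in range(stages))
        self.to_gate = nn.Linear(dim, dim)

    def forward(self, test_feats, template_feats):
        q = test_feats
        for stage in self.stages:
            q = stage(q, template_feats)
        gate = torch.sigmoid(self.to_gate(q))  # gating weights in [0, 1]
        return gate * test_feats               # gated feature maps
```

Feature maps are treated here as token sequences of shape (batch, tokens, channels); the real tracker operates on convolutional feature maps and a segmentation head, which this sketch omits.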

Running Environments

  • Pytorch 1.1.0, Python 3.6.12, Cuda 9.0, torchvision 0.3.0, cudatoolkit 9.0, Matlab R2016b.
  • Ubuntu 16.04, NVIDIA GeForce GTX 1080Ti.

Installation

The instructions have been tested on an Ubuntu 16.04 system. If you run into issues, refer to these two links 1 and 2 for details.

Clone the GIT repository

git clone https://github.com/Yang428/AGST.git

Install dependent libraries

Run the installation script 'install.sh' to install all dependencies. We refer to this link for step-by-step instructions.

bash install.sh conda_install_path pytracking

Or install step by step

conda create -n pytracking python=3.6
conda activate pytracking
conda install -y pytorch=1.1.0 torchvision=0.3.0 cudatoolkit=9.0 -c pytorch
conda install -y matplotlib=2.2.2
conda install -y pandas
pip install opencv-python
pip install tensorboardX
conda install -y cython
pip install pycocotools
pip install jpeg4py 
sudo apt-get install libturbojpeg

Or copy my environment directly.

You can download the packed conda environment from the Baidu cloud link, the extraction code is 'qjl2'.

Download the pre-trained networks

You can download the models from the Baidu cloud link, the extraction code is 'thn4'. Then put the model files 'SegmNet.pth.tar' and 'IoUnet.pth.tar' into the subfolder 'pytracking/networks'.

Testing the tracker

The raw results on eight datasets are provided.

  1. Download the testing datasets OTB-100, LaSOT, Got-10k, TrackingNet, VOT2016, VOT2018, VOT2019 and VOT2020 from the following Baidu cloud links.
  • Got-10k, the extraction code is '78hq'.
  • TrackingNet, the extraction code is '5pj8'.
  • VOT2016, the extraction code is '8f6w'.
  • VOT2018, the extraction code is 'jsgt'.
  • VOT2020, the extraction code is 'kdag'.
  • OTB100, the extraction code is '9x8q'.
  • UAV123, the extraction code is 'vp4r'.
  • LaSOT.
  • NFS, the extraction code is 'gc7u'.
  • TCL128, the extraction code is '1h83'.
  • Or you can download almost all tracking datasets from this web link.
  2. Change the following paths to your own paths.
Network path: pytracking/parameters/agst/agst.py  params.segm_net_path.
Results path: pytracking/evaluation/local.py  settings.network_path, settings.results_path, dataset_path.
  3. Run the AGST tracker on Got10k, TrackingNet, OTB100, UAV123, LaSOT, NFS and TCL128 datasets.
cd pytracking
python run_experiment.py myexperiments got10k
python run_experiment.py myexperiments trackingnet
python run_experiment.py myexperiments otb
python run_experiment.py myexperiments uav
python run_experiment.py myexperiments lasot
python run_experiment.py myexperiments nfs
python run_experiment.py myexperiments tpl
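The path edits in 'pytracking/evaluation/local.py' can be sketched as below. The attribute names follow common pytracking conventions and the paths are placeholders, not the repository's actual values; check your own checkout for the exact attribute names.

```python
# Hypothetical sketch of pytracking/evaluation/local.py after editing.
# EnvSettings is stubbed here so the snippet runs standalone; in the real
# file it is imported from the pytracking evaluation environment module.
class EnvSettings:
    pass

def local_env_settings():
    settings = EnvSettings()
    # where the pre-trained 'SegmNet.pth.tar' and 'IoUnet.pth.tar' live
    settings.network_path = '/home/user/AGST/pytracking/networks/'
    # where raw tracking results are written
    settings.results_path = '/home/user/AGST/pytracking/tracking_results/'
    # dataset roots, one attribute per benchmark (names are assumptions)
    settings.got10k_path = '/data/got10k/'
    settings.otb_path = '/data/otb100/'
    settings.lasot_path = '/data/lasot/'
    return settings
```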

Evaluation on VOT16 and VOT18 using Matlab R2016b

We provide a VOT Matlab toolkit integration for the AGST tracker. The Matlab file 'tracker_AGST.m' in 'pytracking/utils' can be connected to the toolkit; it uses the 'pytracking/vot_wrapper.py' script to integrate the tracker with the toolkit.

Evaluation on VOT2020 and VOT2021 using Python Toolkit

We provide a VOT Python toolkit integration for the AGST tracker. The settings file 'trackers.ini' in 'pytracking/utils' can be connected to the toolkit; it uses the 'pytracking/vot20_wrapper.py' script to integrate the tracker with the toolkit.
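For reference, a trackers.ini entry for the Python toolkit typically looks like the sketch below. The field names follow the VOT toolkit conventions and the path is a placeholder; compare with the provided 'pytracking/utils/trackers.ini' before use.

```ini
[AGST]
label = AGST
protocol = traxpython
command = vot20_wrapper
; directory that contains vot20_wrapper.py
paths = /home/user/AGST/pytracking
```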

cd pytracking/workspace_vot2020
pip install git+https://github.com/votchallenge/vot-toolkit-python
vot initialize vot2020 --workspace ./workspace_vot2020/
vot evaluate AGST
vot analysis --workspace ./workspace_vot2020/AGST

Training the networks

The AGST network is trained on the YouTube VOS, GOT-10K and TrackingNet datasets. Download the VOS training dataset (2018 version) and copy the files vos-list-train.txt and vos-list-val.txt from ltr/data_specs to the training directory of the VOS dataset. Download the bounding boxes from this link and copy them to the corresponding training sequence directories.

  1. Download the YouTube VOS dataset from this link.

  2. Download the GOT-10K dataset from this link.

  3. Download the TrackingNet dataset from this link.

  4. Download the pre-generated masks of GOT-10K and TrackingNet from this link. We refer to this link for more instructions.

  5. Change the following paths to your own paths.

Workspace: ltr/admin/local.py  workspace_dir.
Dataset: ltr/admin/local.py  vos_dir.
  6. Train the AGST network.
cd ltr
python run_training.py segm segm_default
Then enable the distractor-aware loss in 'AGST/ltr/train_settings/segm/segm_default.py' and run an additional 5 epochs.
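The path edits in 'ltr/admin/local.py' might look like the sketch below. The attribute names follow the pytracking/ltr layout and the paths are placeholders; check your own checkout for the exact names.

```python
# Hypothetical sketch of ltr/admin/local.py after editing.
# EnvironmentSettings is stubbed here so the snippet runs standalone; in
# the real file it is imported from the ltr admin environment module.
class EnvironmentSettings:
    pass

def local_env_settings():
    settings = EnvironmentSettings()
    # where checkpoints and logs are written during training
    settings.workspace_dir = '/home/user/AGST/ltr/workspace/'
    # training directory of the YouTube VOS dataset (2018 version),
    # containing the copied vos-list-train.txt and vos-list-val.txt
    settings.vos_dir = '/data/youtube_vos/train/'
    return settings
```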

Acknowledgement

We would like to thank the author Martin Danelljan of pytracking and the author Alan Lukežič of D3S.

BibTex citation:

@ARTICLE{Yijin2024,
title = {Attention-based Gating Network for Robust Segmentation Tracking},
author = {Yang, Yijin and Gu, Xiaodong},
journal = {IEEE Transactions on Circuits and Systems for Video Technology},
year = {2024},
doi = {10.1109/TCSVT.2024.3460400}
}
