Focal Transformer for Image Classification
This document explains how to train Focal Transformers for image classification. The instructions are the same as those for Swin Transformer.
Usage
Install
- Clone this repo:
git clone https://github.com/microsoft/Focal-Transformer.git
cd Focal-Transformer
- Create a conda virtual environment and activate it:
conda create -n focal python=3.7 -y
conda activate focal
- Install PyTorch 1.7.1 and torchvision 0.8.2 with CUDA 10.1:
conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=10.1 -c pytorch
- Install timm 0.3.2:
pip install timm==0.3.2
- Install Apex:
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
- Install other requirements:
pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8
Data preparation
We use the standard ImageNet dataset, which can be downloaded from http://image-net.org/. We support the following two ways to load the data:
- For the standard folder dataset, move the validation images into labeled sub-folders. The file structure should look like this:
$ tree data
imagenet
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...
- To speed up reading images from a massive number of small files, we also support a zipped ImageNet, which consists of four files:
  - train.zip, val.zip: the zipped folders for the train and validation splits.
  - train_map.txt, val_map.txt: the relative path of each image inside the corresponding zip file, together with its ground-truth label.
  Make sure the data folder looks like this:
$ tree data
data
└── ImageNet-Zip
    ├── train_map.txt
    ├── train.zip
    ├── val_map.txt
    └── val.zip
$ head -n 5 data/ImageNet-Zip/val_map.txt
ILSVRC2012_val_00000001.JPEG 65
ILSVRC2012_val_00000002.JPEG 970
ILSVRC2012_val_00000003.JPEG 230
ILSVRC2012_val_00000004.JPEG 809
ILSVRC2012_val_00000005.JPEG 516
$ head -n 5 data/ImageNet-Zip/train_map.txt
n01440764/n01440764_10026.JPEG 0
n01440764/n01440764_10027.JPEG 0
n01440764/n01440764_10029.JPEG 0
n01440764/n01440764_10040.JPEG 0
n01440764/n01440764_10042.JPEG 0
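The relationship between a folder-format split and its zipped counterpart can be sketched with the Python standard library. This is a minimal illustration, not the repository's data loader: the helpers `build_zip_split` and `load_map` are hypothetical, and the sketch assumes labels are assigned by sorted class-folder name (as torchvision's `ImageFolder` does), which is consistent with `n01440764` mapping to `0` above. It writes map entries in the `train_map.txt` style (`<class>/<file> <label>`); the `val_map.txt` shown above uses flat file names instead.

```python
import os
import zipfile


def build_zip_split(split_dir, zip_path, map_path):
    """Pack split_dir/<class>/<img> into a zip plus a '<path> <label>' map file.

    Hypothetical helper. Assumption: label indices follow the sorted order
    of the class folder names.
    """
    classes = sorted(
        d for d in os.listdir(split_dir)
        if os.path.isdir(os.path.join(split_dir, d))
    )
    # JPEGs are already compressed, so store them without recompression.
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as zf, \
            open(map_path, "w") as mf:
        for label, cls in enumerate(classes):
            cls_dir = os.path.join(split_dir, cls)
            for name in sorted(os.listdir(cls_dir)):
                rel = cls + "/" + name  # zip member names always use '/'
                zf.write(os.path.join(cls_dir, name), arcname=rel)
                mf.write(f"{rel} {label}\n")


def load_map(map_path):
    """Parse '<relative path> <label>' lines into (path, label) pairs."""
    entries = []
    with open(map_path) as mf:
        for line in mf:
            if line.strip():
                rel, label = line.rsplit(maxsplit=1)
                entries.append((rel, int(label)))
    return entries
```

For example, `load_map(map_path)[0]` gives the first `(path, label)` pair, and `zipfile.ZipFile(zip_path).read(path)` returns that image's encoded bytes, assuming the zip members are stored at exactly the paths listed in the map file.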
Evaluation
To evaluate a pre-trained Focal Transformer on ImageNet val, run:
python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py --eval \
--cfg <config-file> --resume <checkpoint> --data-path <imagenet-path>
For example, to evaluate Focal-Base with a single GPU:
python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --eval \
--cfg configs/focal_base_patch4_window7_224.yaml --resume focal-base-is224-ws7.pth --data-path <imagenet-path>
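When scripting many evaluations, it can help to assemble the launch command programmatically. The sketch below uses only the flags shown above; the function names `eval_command` and `to_shell` are hypothetical, it substitutes `sys.executable` for `python`, and it only builds the argument list (passing it to `subprocess.run` would actually launch the job).

```python
import shlex
import sys


def eval_command(cfg, checkpoint, data_path, num_gpus=1, master_port=12345):
    """Build the distributed evaluation command shown above as an argv list."""
    return [
        sys.executable, "-m", "torch.distributed.launch",
        "--nproc_per_node", str(num_gpus),
        "--master_port", str(master_port),
        "main.py", "--eval",
        "--cfg", cfg,
        "--resume", checkpoint,
        "--data-path", data_path,
    ]


def to_shell(argv):
    """Quote an argv list for display or copy-paste into a shell."""
    return " ".join(shlex.quote(a) for a in argv)


if __name__ == "__main__":
    cmd = eval_command("configs/focal_base_patch4_window7_224.yaml",
                       "focal-base-is224-ws7.pth", "/path/to/imagenet")
    print(to_shell(cmd))
```

Keeping the command as a list avoids shell-quoting problems when paths contain spaces; `to_shell` is only for printing.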
Training from scratch
To train a Focal Transformer on ImageNet from scratch, run:
python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345 main.py \
--cfg <config-file> --data-path <imagenet-path> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]
Notes:
- To use the zipped ImageNet instead of the folder dataset, add --zip to the parameters.
- To cache the dataset in memory instead of reading it from files every time, add --cache-mode part, which shards the dataset into non-overlapping pieces for different GPUs and loads only the corresponding piece on each GPU.
- When GPU memory is insufficient, you can try the following:
  - Use gradient accumulation by adding --accumulation-steps <steps>, setting <steps> according to your needs.
  - Use gradient checkpointing by adding --use-checkpoint, which can save a lot of GPU memory. Please refer to this page for more details.
- To train very large models, we recommend using multiple nodes with more GPUs; a tutorial can be found on this page.
- To change config options in general, use --opts KEY1 VALUE1 KEY2 VALUE2. For example, --opts TRAIN.EPOCHS 100 TRAIN.WARMUP_EPOCHS 5 changes the total number of epochs to 100 and the number of warm-up epochs to 5.
- For additional options, see config and run python main.py --help for a detailed message.
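The --opts KEY VALUE mechanism overrides nested config entries addressed by dotted keys. yacs (installed above) provides `CfgNode.merge_from_list` for this kind of override; the standalone sketch below only imitates the idea with plain dictionaries and `ast.literal_eval` for value parsing. The name `merge_opts` is hypothetical, and its error handling is a simplification, not yacs behavior (yacs additionally checks that the new value's type matches the old one).

```python
import ast
import copy


def _parse(value):
    """Interpret a command-line string as a Python literal when possible."""
    try:
        return ast.literal_eval(value)
    except (ValueError, SyntaxError):
        return value  # fall back to the raw string, e.g. a file path


def merge_opts(cfg, opts):
    """Return a copy of nested dict `cfg` with KEY VALUE pairs applied.

    Keys use dots to address nested sections, e.g. 'TRAIN.EPOCHS'.
    """
    if len(opts) % 2 != 0:
        raise ValueError("--opts expects KEY VALUE pairs")
    cfg = copy.deepcopy(cfg)  # leave the caller's config untouched
    for key, raw in zip(opts[0::2], opts[1::2]):
        *parents, leaf = key.split(".")
        node = cfg
        for name in parents:
            node = node[name]  # KeyError for an unknown section
        if leaf not in node:
            raise KeyError(f"unknown config key: {key}")
        node[leaf] = _parse(raw)
    return cfg
```

With a base config holding placeholder values such as `{"TRAIN": {"EPOCHS": 300, "WARMUP_EPOCHS": 20}}`, calling `merge_opts(base, ["TRAIN.EPOCHS", "100", "TRAIN.WARMUP_EPOCHS", "5"])` yields `{"TRAIN": {"EPOCHS": 100, "WARMUP_EPOCHS": 5}}`. Rejecting unknown keys catches typos that would otherwise be silently ignored.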
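The --cache-mode part option gives each GPU its own non-overlapping subset of the dataset. One simple rule for this, sketched below, is round-robin assignment by sample index; the helper name `shard_indices` is hypothetical, and whether the repository uses exactly this rule is an assumption.

```python
def shard_indices(num_samples, world_size, rank):
    """Indices of the samples that process `rank` caches (round-robin)."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    return list(range(rank, num_samples, world_size))


if __name__ == "__main__":
    shards = [shard_indices(10, 4, r) for r in range(4)]
    print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
```

Every sample lands in exactly one shard and shard sizes differ by at most one, so each GPU keeps roughly 1/world_size of the data in memory.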
For example, to train a Focal Transformer with 8 GPUs on a single node for 300 epochs, run:
Focal-Tiny:
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \
--cfg configs/focal_tiny_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128
Focal-Small:
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \
--cfg configs/focal_small_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128
Focal-Base:
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \
--cfg configs/focal_base_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 64 \
--accumulation-steps 2 [--use-checkpoint]
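The three recipes above use the same global batch per optimizer step. Assuming the global batch is per-GPU batch size × number of GPUs × accumulation steps (the usual meaning of gradient accumulation), Focal-Tiny and Focal-Small get 128 × 8 × 1 = 1024, and Focal-Base gets 64 × 8 × 2 = 1024. Halving the per-GPU batch and accumulating over two steps keeps the effective batch unchanged while reducing per-GPU activation memory. The function name below is hypothetical.

```python
def effective_batch_size(per_gpu_batch, num_gpus, accumulation_steps=1):
    """Samples contributing to each optimizer step under gradient accumulation."""
    return per_gpu_batch * num_gpus * accumulation_steps


# (per-GPU batch, GPUs, accumulation steps) taken from the commands above.
recipes = {
    "Focal-Tiny": (128, 8, 1),
    "Focal-Small": (128, 8, 1),
    "Focal-Base": (64, 8, 2),
}

if __name__ == "__main__":
    for name, args in recipes.items():
        print(name, effective_batch_size(*args))  # each prints 1024
```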
Throughput
To measure the throughput, run:
python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py \
--cfg <config-file> --data-path <imagenet-path> --batch-size 64 --throughput --amp-opt-level O0
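Throughput here means images processed per second. A common way to measure it, sketched below with only the standard library, is to run a few warm-up iterations, then time a fixed number of forward passes and divide the number of images processed by the elapsed time. `step` stands in for one forward pass on a batch; `measure_throughput` is a hypothetical name, and the iteration counts are illustrative assumptions that the repository's --throughput mode may not share.

```python
import time


def measure_throughput(step, batch_size, warmup=10, iters=30):
    """Return images/second for `step`, a callable that processes one batch."""
    for _ in range(warmup):
        step()  # warm-up: lazy initialization, caches, autotuning, etc.
    start = time.perf_counter()
    for _ in range(iters):
        step()
    elapsed = time.perf_counter() - start
    return batch_size * iters / elapsed
```

With PyTorch on a GPU, `step` would run `model(images)` under `torch.no_grad()` and then call `torch.cuda.synchronize()`, so that asynchronous CUDA kernels are included in the measured time.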