Pyramid Vision Transformer
Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions
模型简介
PVT是一种无需卷积操作的用于密集预测的通用主干网络。PVT在Transformer中引入金字塔结构,以便为密集的预测任务生成多尺度特征图。PVT使用逐步缩小策略,通过块嵌入层来控制特征图的规模,并提出了一个SRA层(spatial-reduction attention, 空间缩减注意)来替代编码器中传统的多头注意力层,大大减少了计算/内存开销。
性能指标
Model |
Context |
Top-1 (%) |
Top-5 (%) |
Params (M) |
Train T. |
Infer T. |
Download |
Config |
Log |
PVT_tiny |
D910x8-G |
74.92 |
|
|
433s/epoch |
16ms/step |
model |
cfg |
log |
PVT_small |
D910x8-G |
79.66 |
|
|
538s/epoch |
30ms/step |
model |
cfg |
log |
PVT_medium |
D910x8-G |
81.82 |
|
|
766s/epoch |
47ms/step |
model |
cfg |
log |
PVT_large |
D910x8-G |
81.75 |
|
|
1074s/epoch |
67ms/step |
model |
cfg |
log |
Notes
- All models are trained on ImageNet-1K training set and the top-1 accuracy is reported on the validatoin set.
- Context: GPU_TYPE x pieces - G/F, G - graph mode, F - pynative mode with ms function.
示例
训练
configs文件夹中列出了mindcv套件所包含的模型的各个规格的yaml配置文件(在ImageNet数据集上训练和验证的配置)。
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
mpirun -n 8 python train.py -c configs/pvt/pvt_tiny_ascend.yaml --data_dir /path/to/imagenet
详细的可调参数及其默认值可以在config.py中查看。
验证
-
下面是使用validate.py
文件验证pvt_tiny的预训练模型的精度的示例。
python validate.py --model=pvt_tiny --dataset=imagenet --val_split=val --pretrained
-
下面是使用validate.py
文件验证pvt_tiny的自定义参数文件的精度的示例。
python validate.py --model=pvt_tiny --dataset=imagenet --val_split=val --ckpt_path='./ckpt/pvt-tiny-best.ckpt'