Training data-efficient image transformers & distillation through attention
Recently, neural networks purely based on attention were shown to address image understanding tasks such as image classification. However, these visual transformers are pre-trained with hundreds of millions of images using expensive infrastructure, which limits their adoption. In this work, we produce competitive convolution-free transformers by training on ImageNet only. We train them on a single computer in less than 3 days. Our reference vision transformer (86M parameters) achieves a top-1 accuracy of 83.1% (single-crop evaluation) on ImageNet with no external data. More importantly, we introduce a teacher-student strategy specific to transformers. It relies on a distillation token ensuring that the student learns from the teacher through attention. We show the benefits of this token-based distillation, especially when using a convnet as a teacher. This leads us to report results competitive with convnets both on ImageNet (where we obtain up to 85.2% accuracy) and when transferring to other tasks. We share our code and models.
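The abstract only names the distillation token. In the paper's hard-label variant, the output of the class token feeds one linear head supervised by the ground-truth label, and the output of the distillation token feeds a second head supervised by the teacher's arg-max prediction, with the two cross-entropy terms weighted equally. The following is a minimal, framework-free sketch of that per-sample objective, assuming logits arrive as plain Python lists; the function names are illustrative and are not MMPretrain APIs.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Cross-entropy of softmax(logits) against an integer class index."""
    # Log-sum-exp with max subtraction for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[target]


def hard_distillation_loss(
    cls_logits: Sequence[float],
    dist_logits: Sequence[float],
    teacher_logits: Sequence[float],
    label: int,
) -> float:
    """Hard-label distillation objective for a single sample (illustrative).

    The class-token head is supervised by the ground-truth label; the
    distillation-token head is supervised by the teacher's hard decision.
    """
    # The teacher's "hard label" is simply its highest-scoring class.
    teacher_label = max(range(len(teacher_logits)), key=teacher_logits.__getitem__)
    return 0.5 * cross_entropy(cls_logits, label) + 0.5 * cross_entropy(
        dist_logits, teacher_label
    )


loss = hard_distillation_loss([2.0, 0.5, -1.0], [0.1, 1.5, 0.0], [0.0, 3.0, 1.0], label=0)
print(loss)
```

Here the teacher predicts class 1 while the ground truth is class 0, so the two heads are pulled toward different targets; this is the situation in which the separate distillation token lets the student learn from the teacher without overriding the true label on the class head.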
Predict image
```python
from mmpretrain import inference_model

predict = inference_model('deit-tiny_4xb256_in1k', 'demo/bird.JPEG')
print(predict['pred_class'])
print(predict['pred_score'])
```
Use the model
```python
import torch
from mmpretrain import get_model

model = get_model('deit-tiny_4xb256_in1k', pretrained=True)
inputs = torch.rand(1, 3, 224, 224)  # dummy batch: one 224x224 RGB image
out = model(inputs)  # forward pass on the dummy batch
print(type(out))
# To extract features.
feats = model.extract_feat(inputs)
print(type(feats))
```
Train/Test Command
Prepare your dataset according to the docs.
Train:

```shell
python tools/train.py configs/deit/deit-tiny_4xb256_in1k.py
```

Test:

```shell
python tools/test.py configs/deit/deit-tiny_4xb256_in1k.py https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_pt-4xb256_in1k_20220218-13b382a0.pth
```
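Config names follow the OpenMMLab convention in which a field such as `4xb256` means 4 GPUs with 256 images per GPU, so both `deit-tiny_4xb256_in1k` and `deit-base_16xb64_in1k` correspond to a total batch size of 1024. The helper below is hypothetical (it is not part of MMPretrain) and simply assumes that `<gpus>xb<batch_per_gpu>` naming pattern.

```python
import re
from typing import NamedTuple


class BatchSetup(NamedTuple):
    gpus: int
    batch_per_gpu: int

    @property
    def total_batch(self) -> int:
        # Effective batch size when every GPU processes its own mini-batch.
        return self.gpus * self.batch_per_gpu


def parse_batch_setup(config_name: str) -> BatchSetup:
    """Extract the '<gpus>xb<batch_per_gpu>' field from a config name (hypothetical helper)."""
    match = re.search(r"(?:^|_)(\d+)xb(\d+)(?:_|$)", config_name)
    if match is None:
        raise ValueError(f"no '<N>xb<B>' field in {config_name!r}")
    return BatchSetup(int(match.group(1)), int(match.group(2)))


for name in ("deit-tiny_4xb256_in1k", "deit-base_16xb64_in1k"):
    setup = parse_batch_setup(name)
    print(name, setup.gpus, setup.batch_per_gpu, setup.total_batch)
# deit-tiny_4xb256_in1k 4 256 1024
# deit-base_16xb64_in1k 16 64 1024
```

Names of converted checkpoints such as `deit-base_3rdparty_in1k` carry no such field, and the sketch raises `ValueError` for them.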
| Model | Pretrain | Params (M) | FLOPs (G) | Top-1 (%) | Top-5 (%) | Config | Download |
|---|---|---|---|---|---|---|---|
| deit-tiny_4xb256_in1k | From scratch | 5.72 | 1.26 | 74.50 | 92.24 | config | model \| log |
| deit-tiny-distilled_3rdparty_in1k* | From scratch | 5.91 | 1.27 | 74.51 | 91.90 | config | model |
| deit-small_4xb256_in1k | From scratch | 22.05 | 4.61 | 80.69 | 95.06 | config | model \| log |
| deit-small-distilled_3rdparty_in1k* | From scratch | 22.44 | 4.63 | 81.17 | 95.40 | config | model |
| deit-base_16xb64_in1k | From scratch | 86.57 | 17.58 | 81.76 | 95.81 | config | model \| log |
| deit-base_3rdparty_in1k* | From scratch | 86.57 | 17.58 | 81.79 | 95.59 | config | model |
| deit-base-distilled_3rdparty_in1k* | From scratch | 87.34 | 17.67 | 83.33 | 96.49 | config | model |
| deit-base_224px-pre_3rdparty_in1k-384px* | 224px | 86.86 | 55.54 | 83.04 | 96.31 | config | model |
| deit-base-distilled_224px-pre_3rdparty_in1k-384px* | 224px | 87.63 | 55.65 | 85.55 | 97.35 | config | model |
*Models with * are converted from the official repo. The config files of these models are only for inference. We have not reproduced their training results.*
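The Params column can be sanity-checked from the standard ViT layout. This sketch assumes that DeiT-Tiny/Small/Base use embedding widths 192/384/768 with 12 transformer blocks, 16x16 patches, an MLP ratio of 4, a fused qkv projection with bias, and a 1000-way linear head; the distilled variants add one extra token and a second 1000-way head. Under those assumptions it reproduces every count in the table. The function `vit_param_count` is illustrative and not part of MMPretrain.

```python
def vit_param_count(
    embed_dim: int,
    depth: int = 12,
    img_size: int = 224,
    patch_size: int = 16,
    mlp_ratio: int = 4,
    num_classes: int = 1000,
    distilled: bool = False,
) -> int:
    """Count the learnable parameters of a plain ViT/DeiT classifier (sketch)."""
    d = embed_dim
    num_patches = (img_size // patch_size) ** 2
    num_extra_tokens = 2 if distilled else 1  # class token (+ distillation token)

    patch_embed = 3 * patch_size * patch_size * d + d  # patchify conv weight + bias
    tokens = num_extra_tokens * d                      # learnable class/dist tokens
    pos_embed = (num_patches + num_extra_tokens) * d   # one position embedding per token
    hidden = mlp_ratio * d
    block = (
        2 * (2 * d)              # two LayerNorms (weight + bias each)
        + d * 3 * d + 3 * d      # fused qkv projection
        + d * d + d              # attention output projection
        + d * hidden + hidden    # MLP fc1
        + hidden * d + d         # MLP fc2
    )
    final_norm = 2 * d
    heads = (2 if distilled else 1) * (d * num_classes + num_classes)
    return patch_embed + tokens + pos_embed + depth * block + final_norm + heads


for label, dim, size, distilled in [
    ("deit-tiny", 192, 224, False),
    ("deit-small-distilled", 384, 224, True),
    ("deit-base 384px", 768, 384, False),
]:
    print(label, round(vit_param_count(dim, img_size=size, distilled=distilled) / 1e6, 2))
# deit-tiny 5.72
# deit-small-distilled 22.44
# deit-base 384px 86.86
```

The same arithmetic shows why the 384px models are much more expensive: they process 24 x 24 = 576 patches plus the extra token(s) instead of 14 x 14 = 196, and since both the per-token layers and the attention matrix grow with sequence length, FLOPs rise more than threefold while the parameter count barely changes.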
MMPretrain does not support training the distilled versions of DeiT; the distilled checkpoints are provided for inference only.
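At test time, the paper's reference rule for distilled models is late fusion: the softmax outputs of the class-token head and the distillation-token head are added before taking the arg-max. The framework-free sketch below illustrates that rule only; whether MMPretrain's inference head combines probabilities or logits is not stated in this README, so treat the helper names as hypothetical.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def fused_prediction(cls_logits: Sequence[float], dist_logits: Sequence[float]) -> int:
    """Late fusion (paper's rule): add both heads' softmax outputs, then take the arg-max."""
    probs = [p + q for p, q in zip(softmax(cls_logits), softmax(dist_logits))]
    return max(range(len(probs)), key=probs.__getitem__)


# The class head narrowly prefers class 0, the distillation head strongly prefers class 1.
print(fused_prediction([2.0, 1.9, -1.0], [0.0, 3.0, 0.0]))  # 1
```

Adding probabilities rather than picking one head lets a confident head outvote a hesitant one, which is what happens in the usage line above.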
```bibtex
@InProceedings{pmlr-v139-touvron21a,
  title     = {Training data-efficient image transformers & distillation through attention},
  author    = {Touvron, Hugo and Cord, Matthieu and Douze, Matthijs and Massa, Francisco and Sablayrolles, Alexandre and Jegou, Herve},
  booktitle = {International Conference on Machine Learning},
  pages     = {10347--10357},
  year      = {2021},
  volume    = {139},
  month     = {July}
}
```