# MultiGrain
<img src="https://user-images.githubusercontent.com/5989894/55801088-8f497280-5ad5-11e9-8759-310cd1da41f4.jpg" height="140">

MultiGrain is a neural network architecture that solves both image classification and image retrieval tasks.
The method is described in "MultiGrain: a unified image embedding for classes and instances" (arXiv link).
BibTeX reference:
```bibtex
@ARTICLE{2019arXivMultiGrain,
  author  = {Berman, Maxim and J{\'e}gou, Herv{\'e} and Vedaldi, Andrea and
             Kokkinos, Iasonas and Douze, Matthijs},
  title   = "{{MultiGrain}: a unified image embedding for classes and instances}",
  journal = {arXiv e-prints},
  year    = "2019",
  month   = "Feb",
}
```
Please cite it if you use it.
## Installation
The MultiGrain code requires
- Python 3.5 or higher
- PyTorch 1.0 or higher

and the packages listed in `requirements.txt`.
The requirements can be installed:
- Either by setting up a dedicated conda environment:
  ```
  conda env create -f environment.yml
  ```
  followed by `source activate multigrain`
- Or with pip:
  ```
  pip install -r requirements.txt
  ```
## Using the code
### Extracting features with pre-trained networks
We provide pre-trained networks with ResNet-50 trunks for the following settings (top-1 accuracies given at scale 224):
| λ | p | augmentation | top-1 | weights |
|---|---|---|---|---|
| 1 | 1 | full | 76.8 | joint_1B_1.0.pth |
| 1 | 3 | full | 76.9 | joint_3B_1.0.pth |
| 0.5 | 1 | full | 77.0 | joint_1B_0.5.pth |
| 0.5 | 3 | full | 77.4 | joint_3B_0.5.pth |
| 0.5 | 3 | autoaugment | 78.2 | joint_3BAA_0.5.pth |
We also provide networks fine-tuned for scales larger than 224, as described in Section E of the supplementary material. Only the pooling exponent is fine-tuned:
| network | scale | p | top-1 | weights |
|---|---|---|---|---|
| NASNet-A-Mobile | 350 px | 1.7 | 75.1 | joint_1B_1.0.pth |
| SENet154 | 400 px | 1.6 | 83.0 | joint_3B_1.0.pth |
| PNASNet-5-Large | 500 px | 1.7 | 83.6 | joint_1B_0.5.pth |
To load a network, use the following PyTorch code:

```python
import torch
from multigrain.lib import get_multigrain

# instantiate the MultiGrain architecture with a ResNet-50 trunk
net = get_multigrain('resnet50')
# load one of the pretrained checkpoints listed above (here λ=1, p=1)
checkpoint = torch.load('joint_1B_1.0.pth')
net.load_state_dict(checkpoint['model_state'])
```
The network accepts images at any resolution. Images are pre-processed by normalization with mean `[0.485, 0.456, 0.406]` and standard deviation `[0.229, 0.224, 0.225]`. The pretrained weights do not include the whitening of the features (important for retrieval), which is specific to each evaluation scale; follow the steps below to compute and apply a whitening.
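For reference, here is a minimal pre-processing and forward-pass sketch using the normalization above, assuming `net` has been loaded as in the previous snippet; `query.jpg` is a placeholder image path, and the structure of the forward output is defined by the MultiGrain code, so consult `multigrain.lib` for how to read the embedding.

```python
import torch
from PIL import Image
from torchvision import transforms

# ImageNet normalization used by the pretrained MultiGrain weights
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),          # any resolution works; 224 shown here
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

img = Image.open('query.jpg').convert('RGB')  # placeholder path
batch = preprocess(img).unsqueeze(0)          # add batch dimension

net.eval()
with torch.no_grad():
    output = net(batch)  # see multigrain.lib for the exact output structure
```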
### Evaluation of the networks
`scripts/evaluate.py` evaluates the network on standard benchmarks.
#### Classification results
Evaluating a network on the ImageNet validation set is straightforward using the options of `evaluate.py`. For instance, the following command:

```bash
IMAGENET_PATH=  # the path that contains the /val and /train image directories
python scripts/evaluate.py --expdir experiments/joint_3B_0.5/eval_p4_500 \
  --imagenet-path $IMAGENET_PATH --input-size 500 --dataset imagenet-val \
  --pooling-exponent 4 --resume-from joint_3B_0.5.pth
```

using the `joint_3B_0.5.pth` pretrained weights, should reproduce the top-1/top-5 accuracies of 78.6%/94.4% reported in Table 2 of the article for ResNet-50 MultiGrain with p=3, λ=0.5 and p*=4, scale s*=500.
#### Retrieval results
The implementation of the evaluation on the retrieval benchmarks in `evaluate.py` is in progress, but one may already use the dataloaders implemented in `datasets/retrieval.py` for this purpose.
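In the meantime, a generic way to evaluate retrieval is to rank database images by descriptor similarity. The sketch below illustrates this in plain PyTorch and makes no assumption about the repository's API beyond having L2-normalized (and whitened) embeddings as tensors; the function name and the toy data are illustrative.

```python
import torch

def rank_by_similarity(query_emb, db_emb):
    """Rank database images by cosine similarity to each query.

    query_emb: (n_queries, d) tensor of L2-normalized query descriptors
    db_emb:    (n_database, d) tensor of L2-normalized database descriptors
    Returns an (n_queries, n_database) tensor of database indices, most similar first.
    """
    similarities = query_emb @ db_emb.t()  # dot product = cosine similarity for unit-norm vectors
    return similarities.argsort(dim=1, descending=True)

# toy usage with random descriptors of dimension 2048 (ResNet-50 embedding size)
queries = torch.nn.functional.normalize(torch.randn(5, 2048), dim=1)
database = torch.nn.functional.normalize(torch.randn(100, 2048), dim=1)
ranking = rank_by_similarity(queries, database)
```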
### Training
The training is performed in three steps. See the help (`-h` flag) of each script for its detailed parameter list. Only the initial joint training script benefits from multi-GPU hardware; the remaining scripts are not parallelized.
#### Joint training
`scripts/train.py` trains a MultiGrain architecture.

Important parameters:
- `--repeated-augmentations`: number of repeated augmentations in the batches; N=3 was used in our joint trainings, N=1 is vanilla uniform sampling.
- `--pooling-exponent`: exponent p in GeM pooling; p=1 is vanilla average pooling (a short sketch of GeM pooling is given after this list).
- `--classif-weight`: weighting factor between the margin loss and the classification loss (parameter λ in the paper).

Other useful parameters:
- `--expdir`: dedicated directory for the experiment.
- `--resume-from`: takes either an expdir or a model checkpoint file to restore from.
- `--pretrained-backbone`: initialize the backbone weights from the model zoo.
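As a reminder of what the pooling exponent controls, here is a minimal sketch of generalized mean (GeM) pooling over a convolutional feature map. It is a simplified illustration, not the repository's implementation (which in particular treats the exponent as a parameter that can be fine-tuned).

```python
import torch

def gem_pooling(features, p=3.0, eps=1e-6):
    """Generalized mean (GeM) pooling of a feature map.

    features: (batch, channels, height, width) activations.
    p = 1 recovers average pooling; large p approaches max pooling.
    """
    clamped = features.clamp(min=eps)  # GeM assumes non-negative activations
    pooled = clamped.pow(p).mean(dim=(-2, -1)).pow(1.0 / p)
    return pooled                      # (batch, channels) embedding

# toy usage on a ResNet-50-like feature map
x = torch.relu(torch.randn(2, 2048, 7, 7))
emb = gem_pooling(x, p=3.0)
```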
#### Input size fine-tuning of GeM exponent
`scripts/finetune_p.py` determines the optimal p* for a given input resolution by fine-tuning (see Section E of the supplementary material for details). Alternatively, one may use cross-validation to determine p*, as done in the main article.
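The underlying idea is to make the pooling exponent a learnable parameter, freeze all other weights, and optimize for a few steps at the target resolution. The following self-contained toy sketch illustrates that pattern; `ToyGeMNet` is a hypothetical stand-in, not the actual MultiGrain model or the internals of `finetune_p.py`.

```python
import torch
import torch.nn as nn

class ToyGeMNet(nn.Module):
    """Toy stand-in for a MultiGrain model: conv trunk + GeM pooling + linear classifier."""
    def __init__(self, n_classes=10):
        super().__init__()
        self.trunk = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.p = nn.Parameter(torch.tensor(3.0))  # learnable GeM exponent
        self.fc = nn.Linear(16, n_classes)

    def forward(self, x):
        feats = torch.relu(self.trunk(x)).clamp(min=1e-6)
        pooled = feats.pow(self.p).mean(dim=(-2, -1)).pow(1.0 / self.p)  # GeM pooling
        return self.fc(pooled)

net = ToyGeMNet()

# freeze everything except the pooling exponent, mimicking exponent-only fine-tuning
for name, param in net.named_parameters():
    param.requires_grad = (name == 'p')

optimizer = torch.optim.SGD([net.p], lr=0.01)
criterion = nn.CrossEntropyLoss()

# one toy optimization step at a larger input size
images = torch.randn(4, 3, 64, 64)
labels = torch.randint(0, 10, (4,))
loss = criterion(net(images), labels)
loss.backward()
optimizer.step()
```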
#### Whitening of the retrieval features
`scripts/whiten.py` computes a PCA whitening and modifies the network accordingly, integrating the reverse transformation into the fully-connected classification layer as described in the article.
The script takes a list and a directory of whitening images; the list given in `data/whiten.txt` is relative to the Multimedia Commons file structure.
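For intuition, here is a minimal sketch of how a PCA whitening transform can be estimated from a set of descriptors and applied to new ones. It is a generic illustration in plain PyTorch, not the actual `whiten.py` implementation (which additionally folds the inverse transform into the classification layer); the function names and toy data are illustrative.

```python
import torch

def fit_pca_whitening(descriptors, eps=1e-6):
    """Estimate a PCA whitening transform from an (n_samples, d) descriptor matrix."""
    mean = descriptors.mean(dim=0)
    centered = descriptors - mean
    cov = centered.t() @ centered / (descriptors.shape[0] - 1)
    eigvals, eigvecs = torch.linalg.eigh(cov)               # eigendecomposition of the covariance
    whitening = eigvecs @ torch.diag((eigvals + eps).rsqrt())  # project on PCA basis, rescale
    return mean, whitening

def apply_whitening(descriptors, mean, whitening):
    """Whiten descriptors and re-normalize them to unit length."""
    whitened = (descriptors - mean) @ whitening
    return torch.nn.functional.normalize(whitened, dim=1)

# toy usage: fit on descriptors of the "whitening images", apply to query descriptors
train_desc = torch.randn(1000, 128)
mean, whitening = fit_pca_whitening(train_desc)
query_desc = apply_whitening(torch.randn(5, 128), mean, whitening)
```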
#### Example training procedure
For example, the results with p=3 and λ=0.5 at scale s*=500 can be obtained with:

```bash
# train network
python scripts/train.py --expdir experiments/joint_3B_0.5 --repeated-augmentations 3 \
  --pooling-exponent 3 --classif-weight 0.5 --imagenet-path $IMAGENET_PATH
# fine-tune p*
python scripts/finetune_p.py --expdir experiments/joint_3B_0.5/finetune500 \
  --resume-from experiments/joint_3B_0.5 --input-size 500 --imagenet-path $IMAGENET_PATH
# whitening
python scripts/whiten.py --expdir experiments/joint_3B_0.5/finetune500_whitened \
  --resume-from experiments/joint_3B_0.5/finetune500 --input-size 500 --whiten-path $WHITEN_PATH
```
### Fine-tuning an existing network
In Appendix E we report fine-tuning results on several pretrained networks. This experiment can be reproduced using the `finetune_p.py` script. For example, in the case of SENet154 at scale s*=450, the following command should yield 83.1% top-1 accuracy with p*=1.6:

```bash
python scripts/finetune_p.py --expdir experiments/se154/finetune450 \
  --pretrained-backbone --imagenet-path $IMAGENET_PATH --input-size 450 --backbone senet154 \
  --no-validate-first
```
## Contributing
See the CONTRIBUTING file for how to help out.
## License
MultiGrain is CC BY-NC 4.0 licensed, as found in the LICENSE file.
The AutoAugment implementation is based on https://github.com/DeepVoltaire/AutoAugment. The Distance Weighted Sampling and margin loss implementation is based on the authors' implementation at https://github.com/chaoyuaw/sampling_matters.