Learning to Weight Samples for Dynamic Early-exiting Networks (ECCV 2022)
Yizeng Han*, Yifan Pu*, Zihang Lai, Chaofei Wang, Shiji Song, Junfeng Cao, Wenhui Huang, Chao Deng, Gao Huang.
*: Equal contribution.
Introduction
This repository contains the implementation of the ECCV 2022 paper Learning to Weight Samples for Dynamic Early-exiting Networks. The proposed method trains a weight prediction network that weights the training losses of individual samples in dynamic early-exiting networks such as MSDNet and RANet, improving their performance when they are deployed with dynamic early exiting.
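To make "weighting the training loss of different samples" concrete, here is a minimal pure-Python sketch. It assumes, as the training flag --meta_net_input_type loss suggests, that the weight network reads one sample's losses at all exits and returns one weight per exit. The class TinyWeightNet, the function weighted_multi_exit_loss and the squashing of outputs into [1 - eps, 1 + eps] are hypothetical illustrations, not the repository's code; flags such as --epsilon and --constraint_dimension are not modeled.

```python
import math
import random


class TinyWeightNet:
    """Hypothetical weight prediction network: a sample's per-exit losses in,
    per-exit loss weights out (one hidden ReLU layer)."""

    def __init__(self, num_exits, hidden_size, eps, seed=0):
        rng = random.Random(seed)
        self.eps = eps  # assumed bound: weights stay in [1 - eps, 1 + eps]
        self.w1 = [[rng.gauss(0.0, 0.1) for _ in range(num_exits)]
                   for _ in range(hidden_size)]
        self.b1 = [0.0] * hidden_size
        self.w2 = [[rng.gauss(0.0, 0.1) for _ in range(hidden_size)]
                   for _ in range(num_exits)]
        self.b2 = [0.0] * num_exits

    def __call__(self, losses):
        # Hidden layer: ReLU(W1 @ losses + b1)
        hidden = [max(0.0, sum(w * l for w, l in zip(row, losses)) + b)
                  for row, b in zip(self.w1, self.b1)]
        # Output layer: W2 @ hidden + b2, then tanh maps it into [-1, 1]
        out = [sum(w * h for w, h in zip(row, hidden)) + b
               for row, b in zip(self.w2, self.b2)]
        return [1.0 + self.eps * math.tanh(o) for o in out]


def weighted_multi_exit_loss(loss_matrix, weight_net):
    """loss_matrix[i][k] is the loss of sample i at exit k.

    Returns the batch mean of sum_k w_ik * loss_ik, where the weights of
    sample i come from weight_net(loss_matrix[i]).
    """
    total = 0.0
    for sample_losses in loss_matrix:
        weights = weight_net(sample_losses)
        total += sum(w * l for w, l in zip(weights, sample_losses))
    return total / len(loss_matrix)


if __name__ == "__main__":
    net = TinyWeightNet(num_exits=5, hidden_size=8, eps=0.3)
    batch = [[2.3, 1.9, 1.2, 0.9, 0.7], [0.4, 0.3, 0.2, 0.2, 0.1]]
    print(net(batch[0]))
    print(weighted_multi_exit_loss(batch, net))
```

With eps = 0 every weight is exactly 1 and the loss reduces to the usual unweighted sum over exits, which is the baseline the weight network departs from.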
Overall idea
<img src="./figs/fig1.jpg" alt="fig1" style="zoom:60%;" />
Training pipeline
Gradient flow of the meta-learning algorithm
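The figures are not reproduced in this text. To give a feel for the gradient flow of a one-step-lookahead meta-reweighting scheme, here is a scalar toy built on hypothetical assumptions: a 1-D linear "backbone" theta, a one-parameter weight net phi giving w_i = 1 + eps * tanh(phi * l_i), and squared-error losses. Each step (1) takes a pseudo SGD step on theta with the weighted training loss, (2) evaluates the loss of the pseudo-updated theta on held-out meta data and propagates its gradient back through the pseudo step to phi with the chain rule, and (3) updates theta for real with the refreshed weights. This is the generic learning-to-reweight pattern, not the repository's code; the paper's meta objective for early-exiting networks and its update schedule are not modeled here.

```python
import math


def per_sample_losses(theta, data):
    """Squared error of the toy backbone y_hat = theta * x."""
    return [(theta * x - y) ** 2 for x, y in data]


def per_sample_grads(theta, data):
    """d loss_i / d theta for each sample."""
    return [2.0 * x * (theta * x - y) for x, y in data]


def weights(phi, losses, eps):
    """Hypothetical one-parameter weight net: w_i = 1 + eps * tanh(phi * l_i)."""
    return [1.0 + eps * math.tanh(phi * l) for l in losses]


def pseudo_update(theta, phi, train, lr, eps):
    """Step (1)/(3): one weighted SGD step on theta; depends on phi via the weights."""
    w = weights(phi, per_sample_losses(theta, train), eps)
    g = per_sample_grads(theta, train)
    return theta - lr * sum(wi * gi for wi, gi in zip(w, g)) / len(train)


def meta_loss(theta, meta):
    """Unweighted mean loss on the held-out meta set."""
    return sum(per_sample_losses(theta, meta)) / len(meta)


def meta_grad(theta, phi, train, meta, lr, eps):
    """Step (2): d meta_loss(theta'(phi)) / d phi by the chain rule."""
    losses = per_sample_losses(theta, train)
    grads = per_sample_grads(theta, train)
    theta_new = pseudo_update(theta, phi, train, lr, eps)
    # d meta_loss / d theta'
    dl_dtheta = sum(2.0 * x * (theta_new * x - y) for x, y in meta) / len(meta)
    # d theta' / d phi: only the weights depend on phi
    dtheta_dphi = -lr * sum(
        eps * (1.0 - math.tanh(phi * l) ** 2) * l * g
        for l, g in zip(losses, grads)
    ) / len(train)
    return dl_dtheta * dtheta_dphi


def train_step(theta, phi, train, meta, lr=0.01, meta_lr=0.1, eps=0.3):
    """Update the weight net from the lookahead, then update the backbone."""
    phi = phi - meta_lr * meta_grad(theta, phi, train, meta, lr, eps)
    theta = pseudo_update(theta, phi, train, lr, eps)
    return theta, phi


if __name__ == "__main__":
    train = [(1.0, 2.0), (2.0, 4.1), (1.5, 9.0)]  # last pair acts as a noisy sample
    meta = [(1.0, 2.0), (3.0, 6.0)]               # held-out "meta" data
    theta, phi = 0.5, 0.0
    for _ in range(50):
        theta, phi = train_step(theta, phi, train, meta)
    print(f"theta={theta:.3f} phi={phi:.3f} meta_loss={meta_loss(theta, meta):.4f}")
```

In a real network the same chain is computed by autograd through a differentiable copy of the backbone update; the toy keeps the derivative explicit so every link in the flow (meta loss, pseudo-updated parameters, weights, weight-net parameters) is visible.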
Usage
Dependencies
- Python: 3.8
- PyTorch: 1.10.0
- Torchvision: 0.11.0
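A small helper can report whether the active environment matches these versions. It is a hypothetical convenience, not part of the repository; it assumes the packages are installed under their PyPI distribution names torch and torchvision, and it ignores local build tags such as +cu113 when comparing.

```python
import sys
from importlib import metadata

EXPECTED = {"torch": "1.10.0", "torchvision": "0.11.0"}


def base_version(version):
    """Drop a local build suffix such as '+cu113' from a version string."""
    return version.split("+", 1)[0]


def check_versions(expected=EXPECTED):
    """Return {package: (installed version or None, expected version)} for each mismatch."""
    mismatches = {}
    for name, want in expected.items():
        try:
            have = metadata.version(name)
        except metadata.PackageNotFoundError:
            have = None
        if have is None or base_version(have) != want:
            mismatches[name] = (have, want)
    return mismatches


if __name__ == "__main__":
    if sys.version_info[:2] != (3, 8):
        print("Python", sys.version.split()[0], "(README lists 3.8)")
    for name, (have, want) in check_versions().items():
        print(f"{name}: found {have}, README lists {want}")
```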
Scripts
- Train an MSDNet (5 exits, step=4) on ImageNet with 8 GPUs:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python tools/main_imagenet_DDP.py \
--train_url YOUR_SAVE_PATH \
--data_url YOUR_DATA_PATH --data ImageNet --workers 64 --seed 0 \
--arch msdnet --nBlocks 5 --stepmode even --step 4 --base 4 --nChannels 32 --growthRate 16 --grFactor 1-2-4-4 --bnFactor 1-2-4-4 \
--meta_net_hidden_size 500 --meta_net_num_layers 1 --meta_interval 100 --meta_lr 1e-4 --meta_weight_decay 1e-4 \
--epsilon 0.3 --target_p_index 15 --meta_net_input_type loss --constraint_dimension mat \
--epochs 100 --batch-size 4096 --lr 0.8 --lr-type cosine --print-freq 10
- Train an MSDNet (5 exits, step=4) on ImageNet on the High-Flyer Yinghuo (hfai) cluster:
hfai python tools/main_imagenet_DDP_HF.py \
--train_url YOUR_SAVE_PATH \
--data_url YOUR_DATA_PATH --data ImageNet --workers 64 --seed 0 \
--arch msdnet --nBlocks 5 --stepmode even --step 4 --base 4 --nChannels 32 --growthRate 16 --grFactor 1-2-4-4 --bnFactor 1-2-4-4 \
--meta_net_hidden_size 500 --meta_net_num_layers 1 --meta_interval 100 --meta_lr 1e-4 --meta_weight_decay 1e-4 \
--epsilon 0.3 --target_p_index 15 --meta_net_input_type loss --constraint_dimension mat \
--epochs 100 --batch-size 4096 --lr 0.8 --lr-type cosine --print-freq 10 \
-- --nodes=1 --name=YOUR_EXPERIMENT_NAME
- Evaluate (anytime):
CUDA_VISIBLE_DEVICES=0 python tools/eval_imagenet.py \
--data ImageNet --batch-size 512 --workers 8 --seed 0 --print-freq 10 --evalmode anytime \
--arch msdnet --nBlocks 5 --stepmode even --step 4 --base 4 --nChannels 32 --growthRate 16 --grFactor 1-2-4-4 --bnFactor 1-2-4-4 \
--data_url YOUR_DATA_PATH \
--train_url YOUR_SAVE_PATH \
--evaluate_from YOUR_CKPT_PATH
- Evaluate (dynamic):
CUDA_VISIBLE_DEVICES=0 python tools/eval_imagenet.py \
--data ImageNet --batch-size 512 --workers 2 --seed 0 --print-freq 10 --evalmode dynamic \
--arch msdnet --nBlocks 5 --stepmode even --step 4 --base 4 --nChannels 32 --growthRate 16 --grFactor 1-2-4-4 --bnFactor 1-2-4-4 \
--data_url YOUR_DATA_PATH \
--train_url YOUR_SAVE_PATH \
--evaluate_from YOUR_CKPT_PATH
- Train an MSDNet (5 exits) on CIFAR-100:
CUDA_VISIBLE_DEVICES=0 python tools/main_cifar_DDP.py \
--train_url YOUR_SAVE_PATH \
--data_url YOUR_DATA_PATH --data cifar100 --workers 1 --seed 1 \
--arch msdnet --nBlocks 5 --stepmode lin_grow --step 1 --base 1 --nChannels 16 \
--meta_net_hidden_size 500 --meta_net_num_layers 1 --meta_interval 1 --meta_lr 1e-4 --meta_weight_decay 1e-4 \
--epsilon 0.8 --target_p_index 15 --meta_net_input_type loss --constraint_dimension col \
--epochs 300 --batch-size 1024 --lr 0.8 --lr-type cosine --print-freq 10
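The anytime and dynamic evaluation modes are not described further here. For orientation, the following is a minimal sketch of budgeted-batch (dynamic) evaluation in the style of MSDNet: on held-out data, fit one confidence threshold per exit so that the share of samples leaving at exit k follows a geometric distribution proportional to p**k, then let each test sample leave at the first exit whose confidence clears its threshold. That tools/eval_imagenet.py --evalmode dynamic works exactly this way, and that the training flag --target_p_index selects one such distribution, are assumptions we have not verified; the function names are hypothetical.

```python
import math


def exit_proportions(p, num_exits):
    """Share of samples assigned to each exit: q_k proportional to p**k, k = 1..num_exits."""
    raw = [p ** k for k in range(1, num_exits + 1)]
    total = sum(raw)
    return [r / total for r in raw]


def find_thresholds(confidences, proportions):
    """Fit one confidence threshold per exit on held-out data.

    confidences[k][i] is the max softmax probability of sample i at exit k.
    Walking the exits in order, threshold k is the confidence of the
    floor(n * proportions[k])-th most confident sample that has not exited yet;
    if there are not enough such samples, nobody exits there. The last exit
    accepts every remaining sample.
    """
    num_exits, n = len(confidences), len(confidences[0])
    exited = [False] * n
    thresholds = [math.inf] * num_exits
    for k in range(num_exits - 1):
        quota = math.floor(n * proportions[k])
        remaining = sorted((i for i in range(n) if not exited[i]),
                           key=lambda i: confidences[k][i], reverse=True)
        if 0 < quota <= len(remaining):
            thresholds[k] = confidences[k][remaining[quota - 1]]
        for i in range(n):
            if confidences[k][i] >= thresholds[k]:
                exited[i] = True
    thresholds[-1] = -math.inf
    return thresholds


def dynamic_predict(sample_confidences, sample_predictions, thresholds):
    """Return (exit index, prediction) at the first exit whose confidence clears its threshold."""
    for k, t in enumerate(thresholds):
        if sample_confidences[k] >= t:
            return k, sample_predictions[k]
    raise ValueError("the last threshold should be -inf")


if __name__ == "__main__":
    conf = [[0.9, 0.8, 0.3, 0.2], [0.95, 0.9, 0.6, 0.5]]  # 2 exits, 4 held-out samples
    thr = find_thresholds(conf, exit_proportions(1.0, 2))
    print(thr, dynamic_predict([0.3, 0.6], ["cat", "dog"], thr))
```

Sweeping p trades accuracy against average compute: a small p sends most samples out at the early, cheap exits, while a p above 1 pushes them toward the deeper ones.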
Results
- CIFAR-10 and CIFAR-100
- ImageNet
Pre-trained Models on ImageNet
| Model config | Epochs | Label smoothing | Exit 1 acc. (%) | Exit 2 acc. (%) | Exit 3 acc. (%) | Exit 4 acc. (%) | Exit 5 acc. (%) | Checkpoint link |
|---|---|---|---|---|---|---|---|---|
| step=4 | 100 | N/A | 59.54 | 67.22 | 71.03 | 72.33 | 73.93 | Tsinghua Cloud / Google Drive |
| step=6 | 100 | N/A | 60.05 | 69.13 | 73.33 | 75.19 | 76.30 | Tsinghua Cloud / Google Drive |
| step=7 | 100 | N/A | 59.24 | 69.65 | 73.94 | 75.66 | 76.72 | Tsinghua Cloud / Google Drive |
| step=4 | 300 | 0.1 | 61.64 | 67.89 | 71.61 | 73.82 | 75.03 | Tsinghua Cloud / Google Drive |
| step=6 | 300 | 0.1 | 61.41 | 70.70 | 74.38 | 75.80 | 76.66 | Tsinghua Cloud / Google Drive |
| step=7 | 300 | 0.1 | 60.94 | 71.88 | 75.13 | 76.03 | 76.82 | Tsinghua Cloud / Google Drive |
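The checkpoints above differ in --step. Assuming the repository keeps the MSDNet-PyTorch convention (not verified here), the first block has --base layers and each later block has --step layers under --stepmode even, or step * i + 1 layers for block i under lin_grow; dash-separated flags such as --grFactor 1-2-4-4 hold one factor per scale. The helpers below are hypothetical and only reproduce that arithmetic.

```python
from itertools import accumulate


def msdnet_block_depths(n_blocks, base, step, stepmode):
    """Layers per MSDNet block (assumed MSDNet-PyTorch convention)."""
    depths = [base]  # the first block always has `base` layers
    for i in range(1, n_blocks):
        if stepmode == "even":
            depths.append(step)
        elif stepmode == "lin_grow":
            depths.append(step * i + 1)
        else:
            raise ValueError(f"unknown stepmode: {stepmode!r}")
    return depths


def exit_depths(depths):
    """Cumulative depth at which each exit classifier is attached."""
    return list(accumulate(depths))


def parse_factors(spec):
    """Parse a dash-separated flag value such as '1-2-4-4' into [1, 2, 4, 4]."""
    return [int(x) for x in spec.split("-")]


if __name__ == "__main__":
    imagenet = msdnet_block_depths(n_blocks=5, base=4, step=4, stepmode="even")
    cifar = msdnet_block_depths(n_blocks=5, base=1, step=1, stepmode="lin_grow")
    print(imagenet, exit_depths(imagenet))
    print(cifar, exit_depths(cifar))
    print(parse_factors("1-2-4-4"))
```

Under this assumption, the ImageNet training command (--nBlocks 5 --base 4 --step 4 --stepmode even) gives five blocks of 4 layers with exits after layers 4, 8, 12, 16 and 20.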
Contact
If you have any questions, please feel free to contact the authors.
Yizeng Han: hanyz18@mails.tsinghua.edu.cn, yizeng38@gmail.com.
Yifan Pu: pyf20@mails.tsinghua.edu.cn, yifanpu98@126.com.
Acknowledgements
Our experiments use the PyTorch implementations from MSDNet-PyTorch, RANet-PyTorch and IMTA.