Home

Awesome

Dataset Distillation by Self-Adaptive Trajectory Matching

gz2_1ipc_aug_noise

This is the official repository for paper Discovering Galaxy Features via Dataset Distillation. Our work contains the experiment code for galaxy dataset distillation and Self-Adaptive Trajectory Matching (STM) algorithm, an improved version of Matching Training Trajectory (MTT).

Getting Started

Clone the repository and create a python environment with pytorch, at least one cuda supported GPU is required.

Generating Expert Trajectories

Before doing any distillation, you'll need to generate some expert trajectories using buffer.py. The following command will train 10 ConvNet models on CIFAR-10 with ZCA whitening for 50 epochs each:

python buffer.py --dataset=CIFAR10 --model=ConvNet --train_epochs=50 --num_experts=10 --zca --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}

Note that experts need only be trained once and can be re-used for multiple distillation experiments. The example trains only 10 experts for fast demo purpose, --num_experts=10. To get our reported results, we train 100 experts for each dataset.

Available Datasets

GZoo2

Galaxy Zoo 2 (GZ2) is a survey based dataset. For better performance, we simplify the classification tree and build a curated version Based on GZ2 Table 5, which contains 243500 images. The simplified classification tree is shown below.

Classification Tree

We sort the confidence of galaxies in descending order and pick the top 600 confident images for each class to form a curated dataset. The curated dataset is divided into 500 training and 100 testing images per class, result of total 4500 training image and 900 testing images. The curated dataset is named as GZoo2.

Download our dataset and save it in {path_to_dataset}.

GZoo2_aug

GZoo2_aug is a dataset generated by applying rotational augmentation on GZoo2. Every training image in GZoo2 is rotated for 36 degree 10 times. GZoo2_aug contains 45000 training images and 900 testing images.

Download our dataset and save it in {path_to_dataset}.

Distill Knowledge from Expert Trajectories

The following command will then use the buffers we generate before to distill CIFAR-10 a 1 image per class synthetic dataset:

python distill.py --dataset=CIFAR10 --ipc=1 --syn_steps=50 --zca --lr_img=1000 --lr_teacher=0.01 --lr_lr=0.01 --buffer_path={path_to_buffer_storage} --data_path={path_to_dataset}

Distillation Hyperparameter

DatasetImg/ClsSynthetic Steps <br />$(N)$Learning Rate <br />(Pixels)Initial Step Size <br />$(\alpha)$Learning Rate <br />(Step Size)
15010000.010.01
CIFAR-10103010000.010.01
503010000.010.01
12010000.010.01
CIFAR-100102010000.010.01
50----
150100000.00010.01
GZoo21020100000.00010.01
50----

Distillation Parameters

ArgumentDescription
--datasetDataset Name: str
--ipcDistillation Image Per Class: int
--syn_stepsSynthetic Step: int
--lr_teacherInitial student network learning rate (step size): float
--lr_lrLearning rate for updating step size: float
--lr_imgLearning rate for updating synthetic image: int
--pix_init=realSynthetic image initialization. Choose from: {noise/real}

Optional Parameters

ArgumentDescription
--wandb_name=Job_1Customize your wandb job name: str
--max_duration=1000Maximum duration for one stage of Hypothesis testing $(T_{max})$: int
--sigma=5The number of sigma $(\lambda\sigma)$ as the critical value of Hypothesis testing: float
--buffer_path={path}Buffer path: str
--data_path={path_to_dataset}Dataset Path: str
--eval_it=100Evaluation interval, smaller value will evaluate more frequently: int
--init_epoch=1Initial trajectory matching epoch pool size, (Default is 1): [0, init_epoch)
--expert_epochs=1MTT parameter. We fix it to 1 for STM: int
--max_start_epoch=29Maximum epoch for stage distillation: int. Use the train_epochs value during expert training.
--Iteration=10000Maximum global iteration, in case distillation run too long: int

WanDB Guide

Stagewise-MTT algorithm generates an independent loss and test_loss for each individual starting epoch. When the values log into WanDB server, each value will be plotted in an individual plot. In this guide, we will show you how to plot them in a same plot for better visualization. For example:

WanDB