Self-Classifier: Self-Supervised Classification Network

Official PyTorch implementation and pretrained models for the ECCV 2022 paper Self-Supervised Classification Network.

<p align="center"> <img src="graphics/Self-Classifier_arch.jpg" width="65%"> </p>

Self-Classifier architecture. Two augmented views of the same image are processed by a shared network composed of a backbone (e.g., a CNN) and a classifier (e.g., a projection MLP followed by a linear classification head). The cross-entropy between the two views' class predictions is minimized to promote consistent class assignments, while degenerate solutions are avoided by asserting a uniform prior over class predictions. The resulting model learns representations and discovers the underlying classes in a single-stage, end-to-end unsupervised manner.
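The objective can be sketched in a few lines. The following is an illustrative simplification only: the temperatures, the exact normalization, and the symmetrization of the loss follow the paper and may differ from this sketch.

```python
import torch
import torch.nn.functional as F

def self_classification_loss(logits1, logits2, row_tau=0.1, col_tau=0.05):
    """Illustrative sketch of a cross-view classification loss.

    Cross-entropy between the class predictions of two augmented views,
    where one view's predictions are also normalized across the batch so
    that a uniform class prior is encouraged (avoiding the degenerate
    all-samples-in-one-class solution). The temperature values here are
    assumptions, not the exact values used in the paper.
    """
    # Targets: class probabilities of view 2 (softmax over classes).
    targets = F.softmax(logits2 / row_tau, dim=1)
    # Predictions of view 1: softmax over the *batch* dimension first
    # (uniform-prior normalization), then renormalized per sample.
    log_p = F.log_softmax(logits1 / col_tau, dim=0)
    log_p = log_p - torch.logsumexp(log_p, dim=1, keepdim=True)
    return -(targets * log_p).sum(dim=1).mean()
```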

If you find this repository useful in your research, please cite:

@article{amrani2021self,
  title={Self-Supervised Classification Network},
  author={Amrani, Elad and Karlinsky, Leonid and Bronstein, Alex},
  journal={arXiv preprint arXiv:2103.10994},
  year={2021}
}

Pretrained Models

Download pretrained models here.
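Checkpoints saved from a `DistributedDataParallel` wrapper typically carry a `module.` prefix on parameter names. A loading sketch is below; the top-level checkpoint keys (e.g., `"state_dict"`) and the backbone architecture are assumptions that depend on how the released models were actually saved.

```python
import torch

def strip_ddp_prefix(state_dict, prefix="module."):
    """Remove the prefix that torch.nn.parallel.DistributedDataParallel
    adds to parameter names when a model is saved from a DDP wrapper."""
    return {k[len(prefix):] if k.startswith(prefix) else k: v
            for k, v in state_dict.items()}

# Hypothetical usage -- adjust key names to the actual checkpoint layout:
# ckpt = torch.load("self_classifier.pth", map_location="cpu")
# model.load_state_dict(strip_ddp_prefix(ckpt["state_dict"]), strict=False)
```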

Setup

  1. Install Conda environment:

     conda env create -f ./environment.yml
    
  2. Install Apex with CUDA extension:

     export TORCH_CUDA_ARCH_LIST="7.0"  # see https://en.wikipedia.org/wiki/CUDA#GPUs_supported
     pip install git+https://github.com/NVIDIA/apex.git@4a1aa97e31ca87514e17c3cd3bbc03f4204579d0 --install-option="--cuda_ext"
    

Training & Evaluation

Distributed training & evaluation is available via Slurm. See SBATCH scripts here.

IMPORTANT: set DATASET_PATH, EXPERIMENT_PATH and PRETRAINED_PATH to match your local paths.

Training

| method | epochs | NMI | AMI | ARI | ACC | linear probing top-1 acc. | training script |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Self-Classifier | 100 | 71.2 | 49.2 | 26.1 | 37.3 | 72.4 | script |
| Self-Classifier | 200 | 72.5 | 51.6 | 28.1 | 39.4 | 73.5 | script |
| Self-Classifier | 400 | 72.9 | 52.3 | 28.8 | 40.2 | 74.2 | script |
| Self-Classifier | 800 | 73.3 | 53.1 | 29.5 | 41.1 | 74.1 | script |

NMI: Normalized Mutual Information; AMI: Adjusted Mutual Information; ARI: Adjusted Rand Index; ACC: unsupervised clustering accuracy. Linear probing: training a supervised linear classifier on top of frozen self-supervised features.
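NMI, AMI, and ARI are available directly from `sklearn.metrics` (`normalized_mutual_info_score`, `adjusted_mutual_info_score`, `adjusted_rand_score`). ACC additionally requires finding the best one-to-one mapping between predicted clusters and ground-truth labels, which the Hungarian algorithm provides; a minimal sketch:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def clustering_accuracy(y_true, y_pred):
    """Unsupervised clustering accuracy: classification accuracy under
    the best one-to-one mapping between predicted clusters and
    ground-truth labels (Hungarian algorithm)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    d = int(max(y_true.max(), y_pred.max())) + 1
    # w[p, t] counts samples assigned to cluster p with true label t.
    w = np.zeros((d, d), dtype=np.int64)
    for t, p in zip(y_true, y_pred):
        w[p, t] += 1
    rows, cols = linear_sum_assignment(-w)  # negate to maximize matches
    return w[rows, cols].sum() / y_true.size
```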

Evaluation

Unsupervised Image Classification

| dataset | NMI | AMI | ARI | ACC | evaluation script |
| --- | --- | --- | --- | --- | --- |
| ImageNet 1K classes | 73.3 | 53.1 | 29.5 | 41.1 | script |
| ImageNet 10 superclasses (level #2 in hierarchy) | 74.0 | 54.3 | 30.9 | 85.7 | script |
| ImageNet 29 superclasses (level #3 in hierarchy) | 74.0 | 54.3 | 30.9 | 79.7 | script |
| ImageNet 128 superclasses (level #4 in hierarchy) | 74.0 | 54.3 | 30.9 | 71.8 | script |
| ImageNet 466 superclasses (level #5 in hierarchy) | 73.9 | 54.3 | 30.8 | 60.0 | script |
| ImageNet 591 superclasses (level #6 in hierarchy) | 74.1 | 55.3 | 32.1 | 46.7 | script |
| BREEDS Entity13 (ImageNet based) | 73.6 | 54.1 | 30.7 | 84.4 | script |
| BREEDS Entity30 (ImageNet based) | 72.9 | 53.4 | 29.8 | 81.0 | script |
| BREEDS Living17 (ImageNet based) | 67.2 | 51.8 | 26.4 | 90.8 | script |
| BREEDS Nonliving26 (ImageNet based) | 72.2 | 57.0 | 36.8 | 76.7 | script |

NMI: Normalized Mutual Information; AMI: Adjusted Mutual Information; ARI: Adjusted Rand Index; ACC: unsupervised clustering accuracy.

K-Means Baselines Using Self-Supervised Pretrained Models

| method | NMI | AMI | ARI | ACC | evaluation script |
| --- | --- | --- | --- | --- | --- |
| BarlowTwins | 68.8 | 48.3 | 14.7 | 33.2 | script |
| OBoW | 66.5 | 42.0 | 16.9 | 31.1 | script |
| DINO | 66.2 | 42.3 | 15.6 | 30.7 | script |
| MoCov2 | 66.6 | 45.3 | 12.0 | 30.6 | script |
| SwAV | 64.1 | 38.8 | 13.4 | 28.1 | script |
| SimSiam | 62.2 | 34.9 | 11.6 | 24.9 | script |

NMI: Normalized Mutual Information; AMI: Adjusted Mutual Information; ARI: Adjusted Rand Index; ACC: unsupervised clustering accuracy. All methods are evaluated on ImageNet 1K classes using the original pretrained models: MoCov2, OBoW, SimSiam, SwAV. DINO and BarlowTwins are loaded via PyTorch Hub (i.e., no direct download needed).
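These baselines amount to running k-means on frozen features extracted from each pretrained backbone. A minimal scikit-learn sketch (feature extraction omitted; the random array is a stand-in for real embeddings, and the dimensions are illustrative):

```python
import numpy as np
from sklearn.cluster import KMeans

# Stand-in for frozen features from a pretrained backbone; in practice
# these would be e.g. 2048-d ResNet-50 embeddings of the dataset.
rng = np.random.default_rng(0)
features = rng.standard_normal((500, 64)).astype(np.float32)

kmeans = KMeans(n_clusters=10, n_init=10, random_state=0)
cluster_ids = kmeans.fit_predict(features)
```

The resulting `cluster_ids` can then be scored against the ground-truth labels with the NMI/AMI/ARI/ACC metrics described above.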

Image Classification with Linear Models

To train a supervised linear classifier on top of a frozen backbone, run:

    sbatch ./scripts/lincls_eval.sh
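The script handles data loading and distributed training; the core of linear probing is just a single linear layer trained on frozen features, roughly as follows (all dimensions and the synthetic data are illustrative; for a ResNet-50 backbone the feature dimension would typically be 2048 and the number of classes 1000):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
feat_dim, num_classes, n = 16, 4, 256

# Synthetic, linearly separable "frozen features" standing in for
# backbone outputs, so the sketch is self-contained.
labels = torch.randint(num_classes, (n,))
features = nn.functional.one_hot(labels, feat_dim).float()
features += 0.1 * torch.randn(n, feat_dim)

head = nn.Linear(feat_dim, num_classes)  # the only trainable part
opt = torch.optim.SGD(head.parameters(), lr=0.5, momentum=0.9)

losses = []
for _ in range(100):
    opt.zero_grad()
    loss = nn.functional.cross_entropy(head(features), labels)
    loss.backward()
    opt.step()
    losses.append(loss.item())
```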

    

Image Classification with kNN

To run a k-nearest-neighbor classifier on the ImageNet validation set, run:

    sbatch ./scripts/knn_eval.sh
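At its core, kNN evaluation compares each validation feature against all training features and takes a vote among the nearest neighbors. A minimal cosine-similarity sketch with a plain majority vote (the actual script may use a weighted vote and a larger k):

```python
import torch

def knn_predict(train_feats, train_labels, test_feats, k=5):
    """Cosine-similarity kNN: for each test feature, find the k most
    similar training features and return the majority label."""
    train = torch.nn.functional.normalize(train_feats, dim=1)
    test = torch.nn.functional.normalize(test_feats, dim=1)
    sims = test @ train.T                    # cosine similarities
    idx = sims.topk(k, dim=1).indices        # k nearest train samples
    return train_labels[idx].mode(dim=1).values
```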

Transferring to Object Detection and Instance Segmentation

See ./detection.

Ablation study

For training the 100-epoch ablation study baseline, run:

    sbatch ./scripts/ablation/train_100ep.sh

For training any of the ablation study runs presented in the paper, run:

    sbatch ./scripts/ablation/<ablation_name>/<ablation_script>.sh
    

Qualitative Examples (classes predicted by Self-Classifier on ImageNet validation set)

<img src="graphics/grid_0.jpg" width="18%"> <img src="graphics/grid_1.jpg" width="18%"> <img src="graphics/grid_2.jpg" width="18%"> <img src="graphics/grid_3.jpg" width="18%"> <img src="graphics/grid_4.jpg" width="18%"> <img src="graphics/grid_5.jpg" width="18%"> <img src="graphics/grid_6.jpg" width="18%"> <img src="graphics/grid_7.jpg" width="18%"> <img src="graphics/grid_8.jpg" width="18%"> <img src="graphics/grid_9.jpg" width="18%"> <img src="graphics/grid_10.jpg" width="18%"> <img src="graphics/grid_11.jpg" width="18%"> <img src="graphics/grid_12.jpg" width="18%"> <img src="graphics/grid_13.jpg" width="18%"> <img src="graphics/grid_14.jpg" width="18%">

High accuracy classes predicted by Self-Classifier on ImageNet validation set. Images are sampled randomly from each predicted class. Note that the predicted classes capture a large variety of different backgrounds and viewpoints.

To reproduce qualitative examples, run:

    sbatch ./scripts/cls_eval.sh

License

See the LICENSE file for more details.