

DeepDPM: Deep Clustering With An Unknown Number of Clusters

This repo contains the official implementation of our CVPR 2022 paper:

DeepDPM: Deep Clustering With An Unknown Number of Clusters

Meitar Ronen, Shahaf Finder and Oren Freifeld.

DeepDPM clustering example on 2D data. On the left: DeepDPM's predicted cluster assignments, centers, and covariances. On the right: clusters colored by the ground-truth (GT) labels, and the network's decision boundary.

[Animation: clustering_example.gif]

[Figure: Examples of the clusters found by DeepDPM on the ImageNet dataset]

Table of Contents
  1. Introduction
  2. Installation
  3. Training
  4. Inference
  5. Citation

Introduction

DeepDPM is a nonparametric deep-clustering method that, unlike most deep-clustering methods, does not require knowing the number of clusters, K; rather, it infers K as part of the overall learning. Using a split/merge framework that adapts the number of clusters during training, together with a novel loss, our method outperforms existing nonparametric methods, both classical and deep.

While the few existing deep nonparametric methods lack scalability, we demonstrate the scalability of ours by being the first such method to report its performance on ImageNet.
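The split/merge framework mentioned above decides whether to split a cluster by asking whether two Gaussians explain its points much better than one. As a hedged, self-contained sketch (not the repo's code), the functions below compute the log Hastings ratio of a proposed split in the style of Dirichlet-process-mixture sub-cluster samplers, H = alpha * Gamma(N1) f(X1) * Gamma(N2) f(X2) / (Gamma(N) f(X)), assuming a Gaussian likelihood whose mean and covariance have a Normal-Inverse-Wishart (NIW) prior, so that f is the NIW marginal likelihood. The names `niw_log_marginal` and `log_split_hastings_ratio` and the prior values in the demo are hypothetical; see the paper for DeepDPM's exact split and merge rules.

```python
import math

import numpy as np
from scipy.special import gammaln, multigammaln


def niw_log_marginal(x, m0, kappa0, nu0, psi0):
    """log p(X) for points X (n x d) under a Gaussian whose mean and covariance
    have a Normal-Inverse-Wishart prior NIW(m0, kappa0, nu0, psi0)."""
    n, d = x.shape
    xbar = x.mean(axis=0)
    centered = x - xbar
    scatter = centered.T @ centered
    kappa_n = kappa0 + n
    nu_n = nu0 + n
    diff = (xbar - m0)[:, None]
    # Posterior scale matrix of the inverse-Wishart part.
    psi_n = psi0 + scatter + (kappa0 * n / kappa_n) * (diff @ diff.T)
    _, logdet0 = np.linalg.slogdet(psi0)
    _, logdet_n = np.linalg.slogdet(psi_n)
    return (
        -0.5 * n * d * math.log(math.pi)
        + multigammaln(0.5 * nu_n, d) - multigammaln(0.5 * nu0, d)
        + 0.5 * nu0 * logdet0 - 0.5 * nu_n * logdet_n
        + 0.5 * d * (math.log(kappa0) - math.log(kappa_n))
    )


def log_split_hastings_ratio(x, labels, alpha, prior):
    """log H for splitting cluster X into sub-clusters labels == 0 and labels == 1.

    prior is the tuple (m0, kappa0, nu0, psi0); alpha is the DP concentration.
    Positive values (H > 1) favor the split.
    """
    x1, x2 = x[labels == 0], x[labels == 1]
    n1, n2, n = len(x1), len(x2), len(x)
    if n1 == 0 or n2 == 0:
        return -math.inf  # an empty sub-cluster is not a real split
    return (
        math.log(alpha)
        + gammaln(n1) + niw_log_marginal(x1, *prior)
        + gammaln(n2) + niw_log_marginal(x2, *prior)
        - gammaln(n) - niw_log_marginal(x, *prior)
    )


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d = 2
    prior = (np.zeros(d), 1.0, d + 2.0, np.eye(d))  # assumed (m0, kappa0, nu0, psi0)
    # Two well-separated blobs: splitting along the true labels gives log H > 0.
    blobs = np.vstack([rng.normal(-5, 1, (100, d)), rng.normal(5, 1, (100, d))])
    print(log_split_hastings_ratio(blobs, np.repeat([0, 1], 100), 1.0, prior))
    # A single blob split at random: log H < 0, so the split is disfavored.
    blob = rng.normal(0, 1, (200, d))
    halves = rng.permutation(np.repeat([0, 1], 100))
    print(log_split_hastings_ratio(blob, halves, 1.0, prior))
```

In samplers of this kind, a proposed split is accepted with probability min(1, H). Working in log space avoids overflow, since the gamma and marginal-likelihood terms grow quickly with the number of points.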

Installation

The code runs with Python version 3.9. Assuming Anaconda, the virtual environment can be installed using:

conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
conda install -c conda-forge pytorch-lightning=1.2.10
conda install -c conda-forge umap-learn
conda install -c conda-forge neptune-client
pip install kmeans-pytorch
conda install psutil numpy pandas matplotlib scikit-learn scipy seaborn tqdm joblib
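After installing, a quick way to confirm that the packages above are importable, and which versions ended up in the environment, is a small check such as the following. This is a hypothetical helper, not part of the repo; the mapping from import names to distribution names (e.g., umap-learn imports as `umap`, neptune-client as `neptune`, scikit-learn as `sklearn`) is our assumption.

```python
import importlib.util
from importlib import metadata  # Python 3.8+

# Import name -> distribution name for the packages installed above (assumed mapping).
REQUIRED = {
    "torch": "torch",
    "torchvision": "torchvision",
    "torchaudio": "torchaudio",
    "pytorch_lightning": "pytorch-lightning",
    "umap": "umap-learn",
    "neptune": "neptune-client",
    "kmeans_pytorch": "kmeans-pytorch",
    "psutil": "psutil",
    "numpy": "numpy",
    "pandas": "pandas",
    "matplotlib": "matplotlib",
    "sklearn": "scikit-learn",
    "scipy": "scipy",
    "seaborn": "seaborn",
    "tqdm": "tqdm",
    "joblib": "joblib",
}


def check_environment(required=REQUIRED):
    """Map each import name to its installed version, or None if not importable."""
    report = {}
    for module, dist in required.items():
        # find_spec locates a top-level module without importing it.
        if importlib.util.find_spec(module) is None:
            report[module] = None
            continue
        try:
            report[module] = metadata.version(dist)
        except metadata.PackageNotFoundError:
            # Importable, but the distribution name differs from our mapping.
            report[module] = "installed (version unknown)"
    return report


if __name__ == "__main__":
    for module, version in check_environment().items():
        print(f"{module:20s} {version or 'MISSING'}")
```

Comparing the printed versions with requirements.txt (for example, pytorch-lightning 1.2.10) helps spot mismatches before training.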

See the requirements.txt file for an overview of the packages in the environment we used to produce our results.

Training

Setup

Datasets and embeddings

When training on raw data (e.g., MNIST or Reuters10k), MNIST is downloaded automatically to the "data" directory. Reuters10k must be downloaded separately (it is available online) and placed in the "data" directory.

Logging

To run the following with logging enabled, edit DeepDPM.py and DeepDPM_alternations.py and insert your Neptune token and project path. Alternatively, run the scripts with the --offline flag to skip logging. In both cases, evaluation metrics are printed at the end of training.

Training models

We provide two models for clustering: DeepDPM, which clusters embedded data, and DeepDPM_alternations, which alternates between feature learning with an autoencoder (AE) and clustering with DeepDPM.

  1. Key hyperparameters:

Please also note the NIW (Normal-Inverse-Wishart) prior hyperparameters and the guidelines for choosing them, as described in the supplementary material.
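As a hedged illustration of what the Normal-Inverse-Wishart (NIW) prior hyperparameters are, the sketch below builds a data-driven prior (m0, kappa0, nu0, psi0) for a Gaussian cluster model. The heuristic (m0 at the data mean, nu0 = d + 2, psi0 as a fraction of the per-dimension data variance) and the names `data_driven_niw_prior` and `psi_scale` are our own assumptions, not the guidelines from the supplementary material; follow those for actual runs.

```python
import numpy as np


def data_driven_niw_prior(x, kappa0=1.0, psi_scale=0.1):
    """Hypothetical data-driven NIW hyperparameters (m0, kappa0, nu0, psi0).

    Assumed heuristic, not the paper's guidelines:
      m0     = mean of the data, so cluster means are shrunk toward it;
      kappa0 = pseudo-count of observations backing m0 (small = weak prior);
      nu0    = d + 2, the smallest integer for which the prior mean of the
               covariance, psi0 / (nu0 - d - 1), exists (it then equals psi0);
      psi0   = psi_scale * diag(per-dimension data variance), i.e. clusters
               are expected a priori to be tighter than the whole dataset.
    """
    x = np.asarray(x, dtype=float)
    _, d = x.shape
    m0 = x.mean(axis=0)
    nu0 = d + 2.0
    psi0 = psi_scale * np.diag(x.var(axis=0))
    return m0, kappa0, nu0, psi0


if __name__ == "__main__":
    feats = np.random.default_rng(0).normal(size=(1000, 3))
    m0, kappa0, nu0, psi0 = data_driven_niw_prior(feats)
    print(nu0)  # 5.0 for d = 3
    print(np.allclose(psi0 / (nu0 - 3 - 1), psi0))  # prior mean covariance equals psi0
```

Choosing nu0 = d + 2 makes psi0 directly interpretable as the expected cluster covariance, which is why the sketch scales it from the data; smaller psi_scale values encode a belief in tighter clusters.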

  2. Training examples:
  3. Training on custom datasets: DeepDPM is designed to cluster data in the feature space. For dimensionality reduction, we suggest using UMAP, an autoencoder, or off-the-shelf unsupervised feature extractors such as MoCo, SimCLR, or SwAV. If the input data is relatively low-dimensional (e.g., at most 128 dimensions), it is possible to train on the raw data.

To load custom data, create a directory containing two files, train_data.pt and test_data.pt, which hold the train and test data tensors, respectively; DeepDPM will load them automatically. If you have labels you wish to load for evaluation, use the --use_labels_for_eval flag.
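The two steps above (reducing dimensionality, then writing train_data.pt and test_data.pt) can be sketched as follows. This is a hypothetical helper, not part of the repo: it assumes each file holds a single 2-D float32 tensor of shape (n_samples, n_features), and it uses PCA purely as a simple stand-in for the UMAP, AE, or self-supervised features recommended above. How to point DeepDPM at the directory, and how label files should be named for --use_labels_for_eval, are not covered here, so check the repo's argument parser.

```python
import os
import tempfile

import torch


def save_custom_dataset(out_dir, train_feats, test_feats, n_components=None):
    """Write train_data.pt and test_data.pt into out_dir (file names from the README).

    Features are stored as float32 tensors of shape (n_samples, n_features);
    the dtype and 2-D layout are assumptions. If n_components is given, both
    splits are projected onto the top principal components of the train split
    (PCA via SVD), a stand-in for UMAP, an AE, or a pretrained extractor.
    """
    train = torch.as_tensor(train_feats, dtype=torch.float32)
    test = torch.as_tensor(test_feats, dtype=torch.float32)
    if n_components is not None:
        mean = train.mean(dim=0, keepdim=True)
        # Rows of vh are the principal directions of the centered train data.
        _, _, vh = torch.linalg.svd(train - mean, full_matrices=False)
        basis = vh[:n_components].T  # (n_features, n_components)
        train = (train - mean) @ basis
        test = (test - mean) @ basis  # reuse the train mean and basis
    os.makedirs(out_dir, exist_ok=True)
    torch.save(train, os.path.join(out_dir, "train_data.pt"))
    torch.save(test, os.path.join(out_dir, "test_data.pt"))
    return tuple(train.shape), tuple(test.shape)


if __name__ == "__main__":
    out = os.path.join(tempfile.mkdtemp(), "my_dataset")  # hypothetical location
    print(save_custom_dataset(out, torch.randn(500, 512), torch.randn(100, 512), n_components=64))
    print(sorted(os.listdir(out)))
```

Fitting the projection on the train split only and reusing it for the test split keeps both files in the same feature space without leaking test statistics into the reduction.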

Note that the saved models in this repo are trained per dataset and are, in most cases, specific to it. We therefore do not recommend using them on custom data.

Inference

To load a pretrained model from a saved checkpoint, and for an inference example, see scripts/DeepDPM_load_from_checkpoint.py.

Citation

For any questions: meitarr@post.bgu.ac.il

Contributions, feature requests, suggestions, etc. are welcome.

If you use this code for your work, please cite the following:

@inproceedings{Ronen:CVPR:2022:DeepDPM,
  title={DeepDPM: Deep Clustering With An Unknown Number of Clusters},
  author={Ronen, Meitar and Finder, Shahaf E. and Freifeld, Oren},
  booktitle={Conference on Computer Vision and Pattern Recognition},
  year={2022}
}