
DCoDR

A Contrastive Objective for Disentangled Representations
Jonathan Kahana and Yedid Hoshen
Official PyTorch Implementation

Abstract: Learning representations of images that are invariant to sensitive or unwanted attributes is important for many tasks including bias removal and cross-domain retrieval. Here, our objective is to learn representations that are invariant to the domain (sensitive attribute) for which labels are provided, while being informative over all other image attributes, which are unlabeled. We present a new approach based on a domain-wise contrastive objective for ensuring invariant representations. This objective crucially restricts negative image pairs to be drawn from the same domain, which enforces domain invariance whereas the standard contrastive objective does not. This domain-wise objective is insufficient on its own, as it suffers from shortcut solutions resulting in feature suppression. We overcome this issue by a combination of a reconstruction constraint, image augmentations, and initialization with pre-trained weights. Our analysis shows that the choice of augmentations is important, and that a misguided choice of augmentations can harm the invariance and informativeness objectives. In an extensive evaluation, our method convincingly outperforms the state-of-the-art in terms of representation invariance, representation informativeness, and training speed. Furthermore, we find that in some cases our method can achieve excellent results even without the reconstruction constraint, leading to much faster and more resource-efficient training.

This repository is the official PyTorch implementation of A Contrastive Objective for Disentangled Representations

<a href="https://arxiv.org/abs/2203.11284" target="_blank"><img src="https://img.shields.io/badge/arXiv-2203.11284-b31b1b.svg"></a>
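To give a feel for the core idea, here is a minimal, hypothetical PyTorch sketch of a domain-wise contrastive objective, not the repository's actual implementation: an InfoNCE-style loss in which each anchor's negatives are restricted to samples from its own domain (all tensor names and the temperature value are placeholders).

```python
import torch
import torch.nn.functional as F

def domain_wise_contrastive_loss(z1, z2, domains, temperature=0.1):
    """InfoNCE-style loss where each anchor's negatives are restricted
    to samples from the same domain as the anchor.

    z1, z2:  (N, D) embeddings of two augmented views of the same images
    domains: (N,)   integer domain label of each image
    """
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)

    # Cosine similarities between every view-1 anchor and every view-2 candidate.
    sim = z1 @ z2.t() / temperature  # (N, N)

    # Keep only same-domain candidates; cross-domain pairs may not serve as
    # negatives. The diagonal (the positive pair) is always same-domain.
    same_domain = domains.unsqueeze(0) == domains.unsqueeze(1)  # (N, N)
    sim = sim.masked_fill(~same_domain, float('-inf'))

    # Softmax cross-entropy with the positive on the diagonal.
    targets = torch.arange(z1.size(0), device=z1.device)
    return F.cross_entropy(sim, targets)
```

A standard contrastive objective would instead let every other sample in the batch act as a negative, which, as the abstract notes, does not enforce domain invariance.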

Usage

By default, the <base-dir> directory is the root directory of the repository; this can be changed in the code itself.

Requirements

- python >= 3.7.3
- cuda >= 11.1
- pytorch >= 1.9.0

Downloading The ImageNet Pre-Trained Weights

Please create a directory <base-dir>/pretrained_weights and put the ImageNet pre-trained weights in it.

MoCo v2 weights can be downloaded from here or from the official GitHub page.
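As a hedged example, here is one way to load an official MoCo v2 checkpoint into a torchvision ResNet-50. The filename and the 'module.encoder_q.' key prefix match the public MoCo checkpoints; adjust them if your file differs.

```python
import torch
from torchvision.models import resnet50

# Example filename; substitute whichever MoCo v2 checkpoint you downloaded.
checkpoint = torch.load(
    'pretrained_weights/moco_v2_800ep_pretrain.pth.tar', map_location='cpu')

# Keep only the query-encoder backbone weights, dropping the MoCo head (fc).
prefix = 'module.encoder_q.'
state_dict = {
    key[len(prefix):]: value
    for key, value in checkpoint['state_dict'].items()
    if key.startswith(prefix) and not key.startswith(prefix + 'fc')
}

model = resnet50()
# strict=False because the fc head was dropped above and stays randomly initialized.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
assert all(k.startswith('fc.') for k in missing) and not unexpected
```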

Downloading the Pre-Processed Datasets

NOTE: You need to download the datasets first.

We provide pre-processed versions of the datasets. They can be found here.

Please put the pre-processed versions under cache/preprocess.

Pre-Processing the Datasets Yourself (Optional)

NOTE: As mentioned above, you can download the pre-processed versions from here.

We also supply scripts for creating the pre-processed versions.

If you download the datasets to other locations, make sure to update the path at the beginning of the corresponding preprocessing script before running it.

The preprocessing can be applied by running:

scripts/prepare_#DATASET_NAME#.py
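For example, assuming the naming convention above, SmallNORB would be pre-processed with:

python scripts/prepare_smallnorb.py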

Training

Given a preprocessed train set and test set, as created by the scripts above, a model can be trained by running one of the attached bash scripts in the bash_scripts folder, according to the desired experiment.

For example, to train DCoDR on SmallNORB, simply run:

bash bash_scripts/DCoDR/smallnorb/DCoDR__smallnorb__pipeline.sh

Trained Models

We provide trained models for all of the evaluated datasets from the main experiment in the paper. Please download the model .pth files as well as the config.pkl file, which is needed for evaluation.

| Dataset | DCoDR-norec | DCoDR |
| --- | --- | --- |
| Cars3D | DCoDR-norec Cars3D | DCoDR Cars3D |
| SmallNorb | DCoDR-norec SmallNorb | DCoDR SmallNorb |
| CelebA | DCoDR-norec CelebA | DCoDR CelebA |
| Edges2Shoes | DCoDR-norec Edges2Shoes | DCoDR Edges2Shoes |
| Shapes3D | DCoDR-norec Shapes3D | DCoDR Shapes3D |
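A minimal sketch of loading a downloaded checkpoint together with its config.pkl follows; the paths below are placeholders, and the model class itself must be instantiated from this repository's own evaluation code.

```python
import pickle
import torch

# Placeholder paths; point these at the files downloaded for your dataset.
with open('trained_models/smallnorb/config.pkl', 'rb') as f:
    config = pickle.load(f)  # experiment configuration needed for evaluation

state_dict = torch.load('trained_models/smallnorb/model.pth', map_location='cpu')

# Build the model as the repository's evaluation code does (using `config`),
# then restore the trained weights:
# model.load_state_dict(state_dict)
```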

Citation

If you find this useful, please cite our paper:

@inproceedings{kahana2022dcodr,
 author = {Kahana, Jonathan and Hoshen, Yedid},
 booktitle = {European Conference on Computer Vision (ECCV)},
 title = {A Contrastive Objective for Disentangled Representations},
 year = {2022}
}