Contrastive Learning Inverts the Data Generating Process [ICML 2021]

Official code to reproduce the results and data presented in the paper Contrastive Learning Inverts the Data Generating Process.

<p align="center"> <img src="https://brendel-group.github.io/cl-ica/img/overview_compressed.svg" alt="Method overview" /> </p>

Experiments

To reproduce the disentanglement results for the MLP mixing experiments, use the main_mlp.py script. For the experiments on KITTI Masks, use main_kitti.py; for those on 3DIdent, use main_3dident.py. An illustrative example invocation follows each usage listing below.

MLP Mixing

> python main_mlp.py --help
usage: main_mlp.py
       [-h] [--sphere-r SPHERE_R] [--box-min BOX_MIN] [--box-max BOX_MAX]
       [--sphere-norm] [--box-norm] [--only-supervised] [--only-unsupervised]
       [--more-unsupervised MORE_UNSUPERVISED] [--save-dir SAVE_DIR]
       [--num-eval-batches NUM_EVAL_BATCHES] [--rej-mult REJ_MULT]
       [--seed SEED] [--act-fct ACT_FCT] [--c-param C_PARAM]
       [--m-param M_PARAM] [--tau TAU] [--n-mixing-layer N_MIXING_LAYER]
       [--n N] [--space-type {box,sphere,unbounded}] [--m-p M_P] [--c-p C_P]
       [--lr LR] [--p P] [--batch-size BATCH_SIZE] [--n-log-steps N_LOG_STEPS]
       [--n-steps N_STEPS] [--resume-training]

Disentanglement with InfoNCE/Contrastive Learning - MLP Mixing

optional arguments:
  -h, --help            show this help message and exit
  --sphere-r SPHERE_R
  --box-min BOX_MIN     For box normalization only. Minimal value of box.
  --box-max BOX_MAX     For box normalization only. Maximal value of box.
  --sphere-norm         Normalize output to a sphere.
  --box-norm            Normalize output to a box.
  --only-supervised     Only train supervised model.
  --only-unsupervised   Only train unsupervised model.
  --more-unsupervised MORE_UNSUPERVISED
                        How many more steps to do for unsupervised compared to
                        supervised training.
  --save-dir SAVE_DIR
  --num-eval-batches NUM_EVAL_BATCHES
                        Number of batches over which to average evaluation
                        performance at the end.
  --rej-mult REJ_MULT   Memory/CPU trade-off factor for rejection resampling.
  --seed SEED
  --act-fct ACT_FCT     Activation function in mixing network g.
  --c-param C_PARAM     Concentration parameter of the conditional
                        distribution.
  --m-param M_PARAM     Additional parameter for the marginal (only relevant
                        if it is not uniform).
  --tau TAU
  --n-mixing-layer N_MIXING_LAYER
                        Number of layers in nonlinear mixing network g.
  --n N                 Dimensionality of the latents.
  --space-type {box,sphere,unbounded}
  --m-p M_P             Type of ground-truth marginal distribution. p=0 means
                        uniform; all other p values correspond to a (projected)
                        Lp Exponential distribution.
  --c-p C_P             Exponent of ground-truth Lp Exponential distribution.
  --lr LR
  --p P                 Exponent of the assumed model Lp Exponential
                        distribution.
  --batch-size BATCH_SIZE
  --n-log-steps N_LOG_STEPS
  --n-steps N_STEPS
  --resume-training
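
For example, a run on 10-dimensional latents on the hypersphere (the hyperparameter values below are illustrative, not the exact settings used in the paper):

> python main_mlp.py --n 10 --space-type sphere --sphere-norm --n-mixing-layer 3 --tau 1.0 --lr 1e-4 --batch-size 6144 --seed 0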

KITTI Masks

> python main_kitti.py --help
usage: main_kitti.py [-h] [--box-norm BOX_NORM] [--p P] [--experiment-dir EXPERIMENT_DIR] [--evaluate] [--specify SPECIFY] [--random-search] [--random-seeds] [--seed SEED] [--beta BETA] [--gamma GAMMA]
                     [--rate-prior RATE_PRIOR] [--data-distribution DATA_DISTRIBUTION] [--rate-data RATE_DATA] [--data-k DATA_K] [--betavae] [--search-beta] [--output-dir OUTPUT_DIR] [--log-dir LOG_DIR]
                     [--ckpt-dir CKPT_DIR] [--max-iter MAX_ITER] [--dataset DATASET] [--batch-size BATCH_SIZE] [--num-workers NUM_WORKERS] [--image-size IMAGE_SIZE] [--use-writer] [--z-dim Z_DIM] [--lr LR]
                     [--beta1 BETA1] [--beta2 BETA2] [--ckpt-name CKPT_NAME] [--log-step LOG_STEP] [--save-step SAVE_STEP] [--kitti-max-delta-t KITTI_MAX_DELTA_T] [--natural-discrete] [--verbose] [--cuda]
                     [--num_runs NUM_RUNS]

Disentanglement with InfoNCE/Contrastive Learning - KITTI Masks

optional arguments:
  -h, --help            show this help message and exit
  --box-norm BOX_NORM
  --p P
  --experiment-dir EXPERIMENT_DIR
                        path to the experiment directory
  --evaluate            evaluate instead of train
  --specify SPECIFY     compute only the specified subset of metrics
  --random-search       whether to run a random search over parameters
  --random-seeds        whether to iterate over random seeds with the UDR parameters
  --seed SEED           random seed
  --beta BETA           weight of the KL term to the normal prior
  --gamma GAMMA         weight of the KL term to the Laplace prior
  --rate-prior RATE_PRIOR
                        rate (inverse scale) of the Laplace prior (larger -> sparser)
  --data-distribution DATA_DISTRIBUTION
                        data distribution; one of (laplace, uniform)
  --rate-data RATE_DATA
                        rate (inverse scale) of the data Laplace distribution (larger -> sparser); -1 = random
  --data-k DATA_K       k for the uniform data distribution; -1 = random
  --betavae             whether to run standard beta-VAE training (gamma=0)
  --search-beta         whether to run a random search over beta
  --output-dir OUTPUT_DIR
                        output directory
  --log-dir LOG_DIR     log directory
  --ckpt-dir CKPT_DIR   checkpoint directory
  --max-iter MAX_ITER   maximum training iteration
  --dataset DATASET     dataset name (dsprites, cars3d, smallnorb, shapes3d, mpi3d, kittimasks, natural)
  --batch-size BATCH_SIZE
                        batch size
  --num-workers NUM_WORKERS
                        dataloader num_workers
  --image-size IMAGE_SIZE
                        image size; currently only (64,64) is supported
  --use-writer          whether to use a log writer
  --z-dim Z_DIM         dimension of the representation z
  --lr LR               learning rate
  --beta1 BETA1         Adam optimizer beta1
  --beta2 BETA2         Adam optimizer beta2
  --ckpt-name CKPT_NAME
                        load a previous checkpoint; specify the checkpoint filename
  --log-step LOG_STEP   number of iterations after which data is logged
  --save-step SAVE_STEP
                        number of iterations after which a checkpoint is saved
  --kitti-max-delta-t KITTI_MAX_DELTA_T
                        maximum time difference between frames sampled by the KITTI data loader
  --natural-discrete    discretize natural sprites
  --verbose             verbose output during evaluation
  --cuda
  --num_runs NUM_RUNS   number of runs when searching over random seeds (e.g., 10)
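
For example, to train on KITTI Masks with illustrative (not paper-exact) settings:

> python main_kitti.py --dataset kittimasks --z-dim 10 --batch-size 64 --lr 1e-4 --max-iter 300000 --seed 0 --cuda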

3DIdent

> python main_3dident.py --help
usage: main_3dident.py [-h] [--batch-size BATCH_SIZE] [--n-eval-samples N_EVAL_SAMPLES] [--lr LR] [--optimizer {adam,sgd}] [--iterations ITERATIONS]
                       [--n-log-steps N_LOG_STEPS] [--load-model LOAD_MODEL] [--save-model SAVE_MODEL] [--save-every SAVE_EVERY] [--no-cuda] [--position-only]
                       [--rotation-and-color-only] [--rotation-only] [--color-only] [--no-spotlight-position] [--no-spotlight-color] [--no-spotlight]
                       [--non-periodic-rotation-and-color] [--dummy-mixing] [--identity-solution] [--identity-mixing-and-solution]
                       [--approximate-dataset-nn-search] --offline-dataset OFFLINE_DATASET [--faiss-omp-threads FAISS_OMP_THREADS]
                       [--box-constraint {None,fix,learnable}] [--sphere-constraint {None,fix,learnable}] [--workers WORKERS]
                       [--mode {supervised,unsupervised,test}] [--supervised-loss {mse,r2}] [--unsupervised-loss {l1,l2,l3,vmf}]
                       [--non-periodical-conditional {l1,l2,l3}] [--sigma SIGMA] [--encoder {rn18,rn50,rn101,rn151}]

Disentanglement with InfoNCE/Contrastive Learning - 3DIdent

optional arguments:
  -h, --help            show this help message and exit
  --batch-size BATCH_SIZE
  --n-eval-samples N_EVAL_SAMPLES
  --lr LR
  --optimizer {adam,sgd}
  --iterations ITERATIONS
                        How long to train the model
  --n-log-steps N_LOG_STEPS
                        How often to calculate scores and print them
  --load-model LOAD_MODEL
                        Path from where to load the model
  --save-model SAVE_MODEL
                        Path where to save the model
  --save-every SAVE_EVERY
                        After how many steps to save the model (will always be saved at the end)
  --no-cuda
  --position-only
  --rotation-and-color-only
  --rotation-only
  --color-only
  --no-spotlight-position
  --no-spotlight-color
  --no-spotlight
  --non-periodic-rotation-and-color
  --dummy-mixing
  --identity-solution
  --identity-mixing-and-solution
  --approximate-dataset-nn-search
  --offline-dataset OFFLINE_DATASET
  --faiss-omp-threads FAISS_OMP_THREADS
  --box-constraint {None,fix,learnable}
  --sphere-constraint {None,fix,learnable}
  --workers WORKERS     Number of workers to use (0=#cpus)
  --mode {supervised,unsupervised,test}
  --supervised-loss {mse,r2}
  --unsupervised-loss {l1,l2,l3,vmf}
  --non-periodical-conditional {l1,l2,l3}
  --sigma SIGMA         Sigma of the conditional distribution (for vMF: 1/kappa)
  --encoder {rn18,rn50,rn101,rn151}
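
Note that --offline-dataset is required. An illustrative call (the remaining hyperparameter values are examples, not the paper's exact settings):

> python main_3dident.py --offline-dataset /path/to/3dident --mode unsupervised --unsupervised-loss l2 --encoder rn18 --batch-size 512 --lr 1e-4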

3DIdent Dataset

We introduce 3DIdent, a dataset with hallmarks of natural environments (shadows, different lighting conditions, 3D rotations, etc.).

<p align="center"> <img src="https://brendel-group.github.io/cl-ica/img/3ddis.svg" alt="3DIdent dataset example images" /> </p>

You can access the full dataset here. The training and test datasets consist of 250000 and 25000 samples, respectively. To load the data, you can use the ThreeDIdentDataset class defined in datasets/threedident_dataset.py.
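
A minimal loading sketch in Python; the constructor arguments below are assumptions for illustration, so check datasets/threedident_dataset.py for the exact signature and the structure of the returned samples:

from torch.utils.data import DataLoader
from datasets.threedident_dataset import ThreeDIdentDataset

# NOTE: the constructor signature is assumed here; see
# datasets/threedident_dataset.py for the actual arguments.
dataset = ThreeDIdentDataset("/path/to/3dident/train")
loader = DataLoader(dataset, batch_size=512, shuffle=True, num_workers=4)
batch = next(iter(loader))  # sample structure depends on the dataset class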

BibTeX

If you find our analysis helpful, please cite our paper:

@inproceedings{zimmermann2021cl,
  author = {
    Zimmermann, Roland S. and
    Sharma, Yash and
    Schneider, Steffen and
    Bethge, Matthias and
    Brendel, Wieland
  },
  title = {
    Contrastive Learning Inverts
    the Data Generating Process
  },
  booktitle = {Proceedings of the 38th International Conference on Machine Learning,
    {ICML} 2021, 18-24 July 2021, Virtual Event},
  series = {Proceedings of Machine Learning Research},
  volume = {139},
  pages = {12979--12990},
  publisher = {{PMLR}},
  year = {2021},
  url = {http://proceedings.mlr.press/v139/zimmermann21a.html},
}