<img src="./propagate.png"></img>

Pixel-level Contrastive Learning

Implementation of Pixel-level Contrastive Learning, proposed in the paper <a href="https://arxiv.org/abs/2011.10043">"Propagate Yourself"</a>, in PyTorch. In addition to doing contrastive learning at the pixel level, the online network passes its pixel-level representations through a Pixel Propagation Module and enforces a similarity loss against the target network. The authors report beating all previous unsupervised and supervised methods on segmentation tasks.
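
To make the Pixel Propagation Module more concrete, here is a minimal sketch in plain Python of the propagation rule described in the paper: each pixel's output is a sum over all pixels of a transformed feature, weighted by the clamped cosine similarity raised to the power `gamma`. This is an illustration only, not this library's code. The function names `cosine` and `propagate` are hypothetical, and the `transform` argument defaults to the identity as a simplifying assumption (the paper uses a small learned transform).

```python
import math

def cosine(u, v):
    """Cosine similarity between two feature vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm > 0 else 0.0

def propagate(pixels, gamma=2.0, transform=lambda v: v):
    """pixels: list of feature vectors, one per spatial location.

    Returns one smoothed vector per pixel:
        y_i = sum_j max(cos(x_i, x_j), 0) ** gamma * transform(x_j)
    `transform` is the identity here (an assumption made for illustration).
    """
    transformed = [transform(p) for p in pixels]
    smoothed = []
    for x_i in pixels:
        # clamp negative similarities to zero, then sharpen with gamma
        weights = [max(cosine(x_i, x_j), 0.0) ** gamma for x_j in pixels]
        y_i = [
            sum(w * t[d] for w, t in zip(weights, transformed))
            for d in range(len(x_i))
        ]
        smoothed.append(y_i)
    return smoothed

# three toy "pixels" with 2-dimensional features
pixels = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
smoothed = propagate(pixels, gamma=2.0)
print(smoothed[0])  # approximately [1.5, 0.5]
```

The first pixel ignores the orthogonal second pixel (similarity 0) and takes half of the third (similarity about 0.707, squared to about 0.5). A larger `gamma` makes the weighting sharper, which is what the `ppm_gamma` setting controls.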

Install

$ pip install pixel-level-contrastive-learning

Usage

Below is an example of how you would use the framework for self-supervised training of a ResNet, taking the output of layer 4 (8 x 8 'pixels').

import torch
from pixel_level_contrastive_learning import PixelCL
from torchvision import models
from tqdm import tqdm

resnet = models.resnet50(pretrained=True)

learner = PixelCL(
    resnet,
    image_size = 256,
    hidden_layer_pixel = 'layer4',  # leads to output of 8x8 feature map for pixel-level learning
    hidden_layer_instance = -2,     # leads to output for instance-level learning
    projection_size = 256,          # size of projection output, 256 was used in the paper
    projection_hidden_size = 2048,  # size of projection hidden dimension, paper used 2048
    moving_average_decay = 0.99,    # exponential moving average decay of target encoder
    ppm_num_layers = 1,             # number of layers for transform function in the pixel propagation module, 1 was optimal
    ppm_gamma = 2,                  # sharpness of the similarity in the pixel propagation module, already at optimal value of 2
    distance_thres = 0.7,           # ideal value is 0.7 according to the paper, which assumes the diagonal of each feature map pixel has length 1 (still unclear)
    similarity_temperature = 0.3,   # temperature for the cosine similarity for the pixel contrastive loss
    alpha = 1.,                     # weight of the pixel propagation loss (pixpro) vs pixel CL loss
    use_pixpro = True,              # use the pixel propagation (pixpro) loss instead of the pixel contrast loss; defaults to True, since pixpro performs best
    cutout_ratio_range = (0.6, 0.8) # a random ratio is selected from this range for the random cutout
).cuda()

opt = torch.optim.Adam(learner.parameters(), lr=1e-4)

def sample_batch_images():
    return torch.randn(10, 3, 256, 256).cuda()

for _ in tqdm(range(100000)):
    images = sample_batch_images()
    loss = learner(images) # if there are no positive pixel pairs, the loss equals the instance-level loss

    opt.zero_grad()
    loss.backward()
    print(loss.item())
    opt.step()
    learner.update_moving_average() # update moving average of target encoder

# after much training, save the improved model for testing on downstream task
torch.save(resnet, 'improved-resnet.pt')
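
For intuition about `moving_average_decay = 0.99` and the `update_moving_average()` call in the loop above, here is a minimal sketch of an exponential moving average update on plain Python lists. It assumes the usual BYOL-style rule (target = decay * target + (1 - decay) * online); the helper `ema_update` is hypothetical and is not this library's internal code.

```python
def ema_update(target, online, decay=0.99):
    """Move each target weight a small step toward the matching online weight."""
    return [decay * t + (1.0 - decay) * o for t, o in zip(target, online)]

# toy "weights" standing in for the target and online encoders
target = [0.0, 1.0]
online = [1.0, 1.0]

for step in range(3):
    target = ema_update(target, online, decay=0.99)
    print(step, target)
```

With a decay of 0.99, the target closes only 1% of the remaining gap to the online weights per step, so it changes slowly and gives the online network a stable signal to match.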

You can also return the number of positive pixel pairs from the forward pass, for logging or other purposes:

loss, positive_pairs = learner(images, return_positive_pairs = True)
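
To see where positive pixel pairs come from, here is a rough sketch in plain Python. It treats two augmented views as crop boxes in the original image, places a grid of feature map pixels over each, and counts the pairs whose centers are closer than `distance_thres` after dividing by the diagonal of one feature map bin. The helpers `pixel_centers` and `count_positive_pairs` are hypothetical, flips are ignored, and normalizing by the larger of the two bin diagonals is an assumption made for illustration; the library's actual computation may differ.

```python
import math

def pixel_centers(box, size):
    """Centers of a size x size grid covering box = (x0, y0, x1, y1), plus the bin diagonal."""
    x0, y0, x1, y1 = box
    bin_w, bin_h = (x1 - x0) / size, (y1 - y0) / size
    centers = [
        (x0 + (col + 0.5) * bin_w, y0 + (row + 0.5) * bin_h)
        for row in range(size)
        for col in range(size)
    ]
    return centers, math.hypot(bin_w, bin_h)

def count_positive_pairs(box_a, box_b, size=8, thres=0.7):
    """Count cross-view pixel pairs whose normalized center distance is below thres."""
    centers_a, diag_a = pixel_centers(box_a, size)
    centers_b, diag_b = pixel_centers(box_b, size)
    diag = max(diag_a, diag_b)  # assumption: normalize by the larger bin diagonal
    return sum(
        1
        for (ax, ay) in centers_a
        for (bx, by) in centers_b
        if math.hypot(ax - bx, ay - by) / diag < thres
    )

# identical views: every pixel matches only itself on an 8 x 8 grid
print(count_positive_pairs((0, 0, 256, 256), (0, 0, 256, 256)))      # 64

# views that do not overlap: no positive pairs at all
print(count_positive_pairs((0, 0, 100, 100), (150, 150, 250, 250)))  # 0
```

In the second case there are no positive pairs, which is the situation where the learner returns only the instance-level loss. Logging `positive_pairs` during training is a quick way to notice if the crops rarely overlap.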

Citations

@misc{xie2020propagate,
    title={Propagate Yourself: Exploring Pixel-Level Consistency for Unsupervised Visual Representation Learning}, 
    author={Zhenda Xie and Yutong Lin and Zheng Zhang and Yue Cao and Stephen Lin and Han Hu},
    year={2020},
    eprint={2011.10043},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}