Home

Awesome

IMPaSh

Implementation of the paper [arXiv]:

"IMPaSh: A Novel Domain-shift Resistant Representation for Colorectal Cancer Tissue Classification" by Trinh Thi Le Vuong, Quoc Dang Vu, Mostafa Jahanifar, Simon Graham, Jin Tae Kwak, and Nasir Rajpoot. ECCV Workshop 2022.

<p align="center"> <img src="figures/Network.png" width="600"> </p>

Model Weights

IMPaSh's encoder and classifier weights:

Snippet of the PatchShuffling module

import numpy as np
from random import shuffle
from PIL import Image

class PatchShuffling:
    """
    PatchShuffling Module

    Lays an ``n_grid`` x ``n_grid`` grid of square cells (side
    ``grid_size = img_size // n_grid`` pixels) over the top-left corner of the
    input image, takes one ``crop_size`` x ``crop_size`` patch at a random
    offset inside each cell, shuffles the patches, and tiles them into a new
    square image with side ``crop_size * n_grid``.

    Args:
        n_grid: number of grid cells along each side.
        img_size: side length, in pixels, of the region covered by the grid.
        crop_size: side length, in pixels, of each patch. Must not exceed the
            grid cell size ``img_size // n_grid``.

    Raises:
        ValueError: if ``crop_size`` is larger than a grid cell.
    """

    def __init__(self, n_grid=3, img_size=255, crop_size=64):
        self.n_grid = n_grid
        self.img_size = img_size
        self.crop_size = crop_size
        self.grid_size = int(img_size // n_grid)
        # Slack inside each cell: patch offsets are drawn from [0, side].
        self.side = self.grid_size - self.crop_size
        if self.side < 0:
            # Fail at construction instead of on every call inside randint.
            raise ValueError(
                "crop_size ({}) must not exceed the grid cell size "
                "img_size // n_grid ({})".format(crop_size, self.grid_size)
            )

        # Row-major enumeration of the cells. Despite the names, ``xx`` holds
        # row (axis-0) positions and ``yy`` column (axis-1) positions.
        rows, cols = np.meshgrid(np.arange(n_grid), np.arange(n_grid), indexing="ij")
        rows, cols = rows.ravel(), cols.ravel()
        # Top-left corner of each cell in the source image ...
        self.xx = rows * self.grid_size
        self.yy = cols * self.grid_size
        # ... and of each tile in the shuffled output image.
        self.re_xx = rows * self.crop_size
        self.re_yy = cols * self.crop_size

    def __call__(self, img):
        """
        Patch-shuffle ``img`` and return the result as a ``PIL.Image``.

        Args:
            img: anything ``np.asarray`` can turn into a ``uint8`` array of
                shape (H, W) or (H, W, C), with H and W at least
                ``n_grid * grid_size``.

        Returns:
            ``PIL.Image`` built from the shuffled array, of side
            ``crop_size * n_grid``.

        Raises:
            ValueError: if ``img`` has the wrong rank or is too small.
        """
        return Image.fromarray(self._shuffle_array(img))

    def _shuffle_array(self, img):
        """Return the patch-shuffled image as a ``uint8`` numpy array."""
        n_cells = self.n_grid * self.n_grid
        # Independent random offsets per cell along axis 0 (r_x) and axis 1 (r_y).
        r_x = np.random.randint(0, self.side + 1, n_cells)
        r_y = np.random.randint(0, self.side + 1, n_cells)
        arr = np.asarray(img, np.uint8)

        # The furthest any patch can reach is the far edge of the last cell;
        # a smaller image would otherwise fail randomly, depending on offsets.
        needed = self.n_grid * self.grid_size
        if arr.ndim not in (2, 3) or min(arr.shape[:2]) < needed:
            raise ValueError(
                "expected an (H, W) or (H, W, C) image with H, W >= {}, "
                "got shape {}".format(needed, arr.shape)
            )

        crops = [
            arr[x0 + dx: x0 + dx + self.crop_size, y0 + dy: y0 + dy + self.crop_size]
            for x0, y0, dx, dy in zip(self.xx, self.yy, r_x, r_y)
        ]
        shuffle(crops)

        out_size = self.crop_size * self.n_grid
        if arr.ndim == 2:
            out_shape = (out_size, out_size)
        else:
            # At least 3 channels: single-channel (H, W, 1) input is broadcast
            # to three channels; wider input (e.g. RGBA) keeps all channels.
            out_shape = (out_size, out_size, max(3, arr.shape[2]))
        shuffled = np.zeros(out_shape, dtype=np.uint8)
        for x0, y0, crop in zip(self.re_xx, self.re_yy, crops):
            shuffled[x0: x0 + self.crop_size, y0: y0 + self.crop_size] = crop
        return shuffled

Train the self-supervised encoder (see Options)

# Self-supervised training of the IMPaSh encoder on the k19 dataset.
# Every trailing "\" is preceded by a space so each flag stays a separate
# argument even if leading indentation is lost when copying.
python main_contrast.py \
  --method IMPaShMoCo \
  --jigsaw_stitch \
  --cosine \
  --dataset_name k19 \
  --multiprocessing-distributed --world-size 1 --rank 0 \
  --dist-url 'tcp://127.0.0.1:23458'

Train the classifier (see Options)

# Train the classifier on k19; --ckpt points at an encoder checkpoint
# (presumably the one written by the self-supervised training step).
# NOTE(review): --method here (PatchSMoco) differs from the encoder-training
# (IMPaShMoCo) and inference (IMPaSH) commands; confirm the names each script
# accepts.
python main_linear.py \
  --method PatchSMoco \
  --ckpt ./save/k19_IMPaSH/ckpt_epoch_200.pth \
  --aug_linear RA \
  --dataset_name k19 \
  --keephead head \
  --multiprocessing-distributed --world-size 1 --rank 0 \
  --dist-url 'tcp://127.0.0.1:23458'

Run inference on the target dataset

# Run inference on the target dataset (k16). --ckpt is the encoder checkpoint
# and --ckpt_class the classifier checkpoint (presumably produced by the two
# training commands).
python main_infer.py \
  --method IMPaSH \
  --ckpt ./save/k19_IMPaSH/ckpt_epoch_200.pth \
  --ckpt_class ./save/k19_IMPaSH_linear_head_True/ckpt_epoch_40.pth \
  --dataset_name k16 \
  --keephead head \
  --multiprocessing-distributed --world-size 1 --rank 0 \
  --dist-url 'tcp://127.0.0.1:23452'