Pseudo-Label Guided Image Synthesis for Semi-Supervised COVID-19 Pneumonia Infection Segmentation

Implementation of Pseudo-Label Guided Image Synthesis for Semi-Supervised COVID-19 Pneumonia Infection Segmentation.

<p align="center"> <img src="framework.png" align="center" width="50%"> </p>

Implementation

1. Installation

Main dependency: pytorch==1.9.0

2. Dataset Preparation

├── COVID249
│   ├── NII (original dataset in NIfTI)
│   ├── PNG (pre-processed dataset in PNG)
│   ├── train_0.1_l.xlsx (labeled split for the 10% setting)
│   ├── train_0.1_u.xlsx (unlabeled split for the 10% setting)
│   ├── train_0.2_l.xlsx (labeled split for the 20% setting)
│   ├── train_0.2_u.xlsx (unlabeled split for the 20% setting)
│   ├── train_0.3_l.xlsx (labeled split for the 30% setting)
│   ├── train_0.3_u.xlsx (unlabeled split for the 30% setting)
│   ├── test_slice.xlsx (data split for testing)
│   ├── val_slice.xlsx (data split for validation)
├── MOS1000
│   ├── NII (original dataset in NIfTI)
│   ├── PNG (pre-processed dataset in PNG)
│   ├── train_slice_label.xlsx (labeled training split)
│   ├── train_slice_unlabel.xlsx (unlabeled training split)
│   ├── test_slice.xlsx (data split for testing)
│   ├── val_slice.xlsx (data split for validation)
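The split files follow a fixed naming pattern, so they can be resolved from a dataset root and a labeled ratio. The sketch below is a minimal illustration under two assumptions: the layout above sits under a root folder of your choice, and the `_l`/`_u` suffixes mean labeled/unlabeled (matching the MOS1000 file names). `split_files` is a hypothetical helper, not a function from this repository.

```python
from pathlib import Path


def split_files(root, dataset, ratio=None):
    """Hypothetical helper: map a dataset name (and labeled ratio) to its split files.

    Assumes the directory tree shown above; `_l` = labeled, `_u` = unlabeled.
    """
    base = Path(root) / dataset
    if dataset == "COVID249":
        # COVID249 ships training splits only for 10%, 20% and 30% labeled data.
        if ratio not in (0.1, 0.2, 0.3):
            raise ValueError("COVID249 provides splits for ratio 0.1, 0.2 or 0.3")
        labeled, unlabeled = f"train_{ratio}_l.xlsx", f"train_{ratio}_u.xlsx"
    elif dataset == "MOS1000":
        # MOS1000 has a single labeled/unlabeled training split.
        labeled, unlabeled = "train_slice_label.xlsx", "train_slice_unlabel.xlsx"
    else:
        raise ValueError(f"unknown dataset: {dataset}")
    return {
        "labeled": base / labeled,
        "unlabeled": base / unlabeled,
        "val": base / "val_slice.xlsx",
        "test": base / "test_slice.xlsx",
    }
```

For example, `split_files("data", "COVID249", 0.2)["labeled"]` points at `data/COVID249/train_0.2_l.xlsx`.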

3. Training Our Models

python train_SACPS.py
python train_SAST.py

4. Training Other Models

We provide a template for training other models, in which the dataloader, optimizer, etc. are already implemented. The core code is shown below:

# Assumed to be defined earlier in the template: optimizer, writer,
# labeled_dataloader, unlabeled_dataloader, max_epoch, base_lr, max_iterations.
import logging
from tqdm import tqdm

iter_num = 0  # global iteration counter used by the learning-rate schedule
for epoch in range(max_epoch):
    print("Start epoch", epoch + 1, "!")

    # One epoch = one pass over the unlabeled data; the labeled loader restarts when exhausted.
    tbar = tqdm(range(len(unlabeled_dataloader)), ncols=70)
    labeled_dataloader_iter = iter(labeled_dataloader)
    unlabeled_dataloader_iter = iter(unlabeled_dataloader)

    for batch_idx in tbar:
        # load data
        try:
            input_l, target_l, file_name_l, lung_l = next(labeled_dataloader_iter)
        except StopIteration:
            labeled_dataloader_iter = iter(labeled_dataloader)
            input_l, target_l, file_name_l, lung_l = next(labeled_dataloader_iter)
        input_ul, target_ul, file_name_ul, lung_ul = next(unlabeled_dataloader_iter)

        # move tensors to the GPU
        input_ul, target_ul, lung_ul = input_ul.cuda(non_blocking=True), target_ul.cuda(non_blocking=True), lung_ul.cuda(non_blocking=True)
        input_l, target_l, lung_l = input_l.cuda(non_blocking=True), target_l.cuda(non_blocking=True), lung_l.cuda(non_blocking=True)


        # Add implementation here: the training process, which must compute `loss`
        #-------------------------------------------------------------
        #*************************************************************
        #-------------------------------------------------------------


        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # polynomial learning-rate decay
        lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_

        iter_num = iter_num + 1
        writer.add_scalar('info/lr', lr_, iter_num)
        writer.add_scalar('info/total_loss', loss.item(), iter_num)
        logging.info('iteration %d : loss : %f' % (iter_num, loss.item()))
writer.close()
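Two pieces of bookkeeping in the loop above can be factored out: cycling the labeled loader so that the unlabeled loader sets the epoch length, and the polynomial learning-rate decay. A minimal plain-Python sketch follows; `cycle_loader` and `poly_lr` are hypothetical names, not functions from this repository. Any re-iterable loader works, including a PyTorch `DataLoader`.

```python
def cycle_loader(loader):
    """Yield batches forever, re-iterating `loader` each time it is exhausted.

    Same effect as the try/except StopIteration pattern in the template.
    Unlike itertools.cycle, it re-iterates the loader instead of replaying
    cached batches, so a shuffling DataLoader reshuffles on every pass.
    """
    while True:
        empty = True
        for batch in loader:
            empty = False
            yield batch
        if empty:
            # Guard against spinning forever on a loader with no batches.
            raise ValueError("loader yielded no batches")


def poly_lr(base_lr, iter_num, max_iterations, power=0.9):
    """Polynomial decay: base_lr * (1 - t / T) ** power, with t / T clamped to 1."""
    progress = min(iter_num / max_iterations, 1.0)
    return base_lr * (1.0 - progress) ** power


# Toy usage with lists standing in for DataLoaders.
labeled = cycle_loader([["l0"], ["l1"]])
for unlabeled_batch in [["u0"], ["u1"], ["u2"]]:
    print(next(labeled), unlabeled_batch)
```

In the template, `next(labeled_iter)` on a `cycle_loader(labeled_dataloader)` created per epoch would replace the try/except block, and `poly_lr(base_lr, iter_num, max_iterations)` the inline formula. The clamp matters because in Python a negative float raised to 0.9 yields a complex number, which would happen if `iter_num` ever exceeded `max_iterations`.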

5. Testing

python segment_test.py

The core testing code is shown below:

def test(args, snapshot_path):
    model = net_factory(net_type=args.model)
    model.load_state_dict(torch.load(args.model_path))
    model.eval()

    nsd, dice = get_model_metric(args = args, model = model, snapshot_path=snapshot_path, model_name='model', mode='test')
    print('nsd : %f dice : %f ' % (nsd, dice))
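`get_model_metric` reports Dice and NSD (a surface-distance-based measure). For reference, here is a minimal sketch of the Dice coefficient, 2|P ∩ T| / (|P| + |T|), on flattened binary masks. `dice_score` and its empty-mask convention are assumptions for illustration, not the repository's `get_model_metric`; NSD is not covered.

```python
def dice_score(pred, target, empty_value=1.0):
    """Dice coefficient for two flattened binary masks (sequences of 0/1 values)."""
    if len(pred) != len(target):
        raise ValueError("masks must have the same number of pixels")
    # Pixels that are foreground in both masks.
    intersection = sum(1 for p, t in zip(pred, target) if p and t)
    # Foreground pixel count of each mask, summed.
    total = sum(1 for p in pred if p) + sum(1 for t in target if t)
    if total == 0:
        # Both masks empty: assumed convention, counted as perfect agreement.
        return empty_value
    return 2.0 * intersection / total
```

For instance, `dice_score([1, 1, 0, 0], [1, 0, 0, 0])` gives 2 · 1 / (2 + 1) ≈ 0.667.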

Supplementary information

  1. Statistics of the datasets.

Descriptive statistics of both datasets, including the x-, y- and z-spacing, are shown below.

<p align="center"> <img src="dataset.png" align="center" width='100%' height="100%"> </p>
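Given one (x, y, z) voxel-spacing tuple per volume, for example read from each NIfTI header, per-axis descriptive statistics can be computed as in the sketch below. `spacing_summary` is a hypothetical helper; reading the spacings from the files is left out, the standard deviation is the population one (an assumed choice), and the toy numbers in the test are not the datasets' values.

```python
import statistics


def spacing_summary(spacings):
    """Per-axis min/max/mean/std from a list of (x, y, z) spacing tuples (e.g. in mm)."""
    summary = {}
    # zip(*spacings) regroups the tuples into all x-, all y- and all z-values.
    for axis, values in zip("xyz", zip(*spacings)):
        summary[axis] = {
            "min": min(values),
            "max": max(values),
            "mean": statistics.mean(values),
            "std": statistics.pstdev(values),  # population standard deviation
        }
    return summary
```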
  2. Links for competing methods.

Citation

If you find this repository useful for your research, please cite the following:

@ARTICLE{9931157,
  author={Lyu, Fei and Ye, Mang and Carlsen, Jonathan Frederik and Erleben, Kenny and Darkner, Sune and Yuen, Pong C.},
  journal={IEEE Transactions on Medical Imaging}, 
  title={Pseudo-Label Guided Image Synthesis for Semi-Supervised COVID-19 Pneumonia Infection Segmentation}, 
  year={2022},
  volume={},
  number={},
  pages={1-1},
  doi={10.1109/TMI.2022.3217501}}

Acknowledgments

We thank Xiangde Luo for sharing his code; our code borrows heavily from https://github.com/HiLab-git/SSL4MIS.