U-Net for Semantic Segmentation

Overview

This repo contains code to train and test a U-Net for semantic segmentation of images. It supports both conventional (centralized) training and federated training using the FedAvg algorithm in the Flower framework.
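FedAvg builds the global model by averaging each parameter across clients, weighting each client by the number of local training examples it holds. Flower's built-in FedAvg strategy performs this step on the server. The pure-Python sketch below only illustrates the arithmetic. Its parameter representation (name mapped to a flat list of floats) and the `fedavg` helper are hypothetical and are not the repo's code.

```python
from typing import Dict, List, Tuple

# Hypothetical representation: parameter name -> flat list of float values.
Params = Dict[str, List[float]]


def fedavg(client_updates: List[Tuple[Params, int]]) -> Params:
    """Average client parameters, weighting each client by its number of examples."""
    if not client_updates:
        raise ValueError("need at least one client update")
    total = sum(n for _, n in client_updates)
    if total <= 0:
        raise ValueError("total number of examples must be positive")
    first_params, _ = client_updates[0]
    averaged: Params = {}
    for name, values in first_params.items():
        acc = [0.0] * len(values)
        for params, n in client_updates:
            weight = n / total  # this client's share of all training examples
            for i, v in enumerate(params[name]):
                acc[i] += weight * v
        averaged[name] = acc
    return averaged


if __name__ == "__main__":
    # Two toy clients: the second holds three times as much data.
    c1 = {"conv.weight": [1.0, 2.0]}
    c2 = {"conv.weight": [5.0, 6.0]}
    print(fedavg([(c1, 1), (c2, 3)]))  # {'conv.weight': [4.0, 5.0]}
```

Clients with more data pull the average toward their own weights. This is the main difference from a plain unweighted mean.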

Getting Datasets

sh getAllData.sh

Federated Testing commands

Inference Testing

python inference.py --data data/CityScape-Dataset/C1-Vehicle_NoPeople-65 --img data/CityScape-Dataset/C1-Vehicle_NoPeople-65/Image/ulm_000009_000019_leftImg8bit.png --meta data/CityScape-Dataset --checkpoint saved_models/unet_epoch_0_1.67928.pt --ind 0
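A segmentation network outputs one score per class for every pixel. The predicted mask keeps the highest-scoring class at each pixel. The sketch below shows that step on nested lists. It assumes scores laid out as `[class][row][col]` (the channel-first order PyTorch models typically use). The `predict_mask` helper is hypothetical, and inference.py's own post-processing is not reproduced here.

```python
from typing import List

# Scores laid out as [num_classes][height][width] (channel-first; an assumption).
Scores = List[List[List[float]]]


def predict_mask(scores: Scores) -> List[List[int]]:
    """Return a [height][width] mask holding the argmax class index per pixel."""
    num_classes = len(scores)
    height = len(scores[0])
    width = len(scores[0][0])
    mask = []
    for r in range(height):
        row = []
        for c in range(width):
            pixel = [scores[k][r][c] for k in range(num_classes)]
            # Index of the largest score; ties resolve to the lower class index.
            row.append(max(range(num_classes), key=pixel.__getitem__))
        mask.append(row)
    return mask


if __name__ == "__main__":
    # 2 classes on a 1x3 image.
    scores = [
        [[0.9, 0.2, 0.4]],  # class 0
        [[0.1, 0.8, 0.6]],  # class 1
    ]
    print(predict_mask(scores))  # [[0, 1, 1]]
```

With a PyTorch model, the same step is usually `output.argmax(dim=1)` on an `(N, C, H, W)` output tensor.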

Federated server

python server.py > server.txt

Federated Clients (CityScape)

python client_kd.py --data data/CityScape-Dataset/Train/C1-Vehicle_NoPeople-65 --test data/CityScape-Dataset/Test --meta data/CityScape-Dataset --num_epochs 50 --loss crossentropy --name client1 
python client_kd.py --data data/CityScape-Dataset/Train/C2-People_NoVehicle-22 --test data/CityScape-Dataset/Test --meta data/CityScape-Dataset --num_epochs 50 --loss crossentropy --name client2
python client_kd.py --data data/CityScape-Dataset/Train/C3-NoVehicle_NoPeople-11 --test data/CityScape-Dataset/Test --meta data/CityScape-Dataset --num_epochs 50 --loss crossentropy --name client3
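client_kd.py loads a teacher model from a centrally trained checkpoint (see step 6 under FedUKD), so the clients train with knowledge distillation. The repo's exact loss is not documented here. As an assumption, the sketch below shows the standard Hinton-style distillation loss for a single pixel's class scores. The loss is `alpha` times the cross-entropy against the ground-truth label, plus `(1 - alpha)` times the temperature-scaled KL divergence between the teacher's and the student's softened distributions. The names `kd_loss`, `alpha` and `temperature` are hypothetical.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax with temperature scaling."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(student_logits: List[float], teacher_logits: List[float],
            label: int, alpha: float = 0.5, temperature: float = 2.0) -> float:
    """Hinton-style distillation loss for a single pixel (hypothetical defaults).

    alpha weights the hard-label cross-entropy; (1 - alpha) weights the
    soft-target KL term, scaled by temperature**2 so its gradients stay
    comparable in size across temperatures.
    """
    # Hard-label cross-entropy on the student's ordinary (T = 1) softmax.
    ce = -math.log(softmax(student_logits)[label])
    # KL(teacher || student) on temperature-softened distributions.
    p_t = softmax(teacher_logits, temperature)
    p_s = softmax(student_logits, temperature)
    kl = sum(pt * (math.log(pt) - math.log(ps)) for pt, ps in zip(p_t, p_s) if pt > 0)
    return alpha * ce + (1 - alpha) * temperature ** 2 * kl
```

For a full segmentation map, the per-pixel values would be averaged over all pixels. When teacher and student scores are identical, the KL term vanishes and the loss reduces to `alpha` times the plain cross-entropy.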

Federated Clients (Chennai)

python client_kd.py --data data/Chennai-Dataset/Train/D1 --meta data/Chennai-Dataset --test data/Chennai-Dataset/Test/T1 --num_epochs 2 --loss crossentropy --name clientCHN1 > clientCHN1.txt
python client_kd.py --data data/Chennai-Dataset/Train/D2 --meta data/Chennai-Dataset --test data/Chennai-Dataset/Test/T2 --num_epochs 2 --loss crossentropy --name clientCHN2 > clientCHN2.txt
python client_kd.py --data data/Chennai-Dataset/Train/D3 --meta data/Chennai-Dataset --test data/Chennai-Dataset/Test/T3 --num_epochs 2 --loss crossentropy --name clientCHN3 > clientCHN3.txt

Unified Testing commands

Unified Testing on CityScape Dataset

python train.py --data data/CityScape-Dataset-Unified/Train --test data/CityScape-Dataset-Unified/Test --meta data/CityScape-Dataset-Unified --num_epochs 50 --loss crossentropy --name UnifiedCSP > UnifiedCSP.txt

Unified Testing on Chennai Dataset

python train.py --data data/Chennai-Dataset-Unified/Train --meta data/Chennai-Dataset-Unified --test data/Chennai-Dataset-Unified/Test --num_epochs 50 --loss crossentropy --name UnifiedCHN > UnifiedCHN.txt
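With `--loss crossentropy`, a segmentation loss is usually the per-pixel cross-entropy averaged over all pixels. The pure-Python sketch below shows what that quantity is. It assumes logits laid out as `[class][row][col]` and integer targets as `[row][col]`. The `pixelwise_cross_entropy` helper is hypothetical and is not train.py's implementation.

```python
import math
from typing import List


def pixelwise_cross_entropy(logits: List[List[List[float]]],
                            target: List[List[int]]) -> float:
    """Mean cross-entropy over all pixels.

    logits: [num_classes][height][width] raw scores (assumed channel-first layout).
    target: [height][width] ground-truth class indices.
    """
    num_classes = len(logits)
    height, width = len(target), len(target[0])
    total = 0.0
    for r in range(height):
        for c in range(width):
            scores = [logits[k][r][c] for k in range(num_classes)]
            m = max(scores)
            # Log-sum-exp, computed stably by subtracting the max score.
            log_norm = m + math.log(sum(math.exp(s - m) for s in scores))
            # -log softmax(scores)[target] = log_norm - score of the true class.
            total += log_norm - scores[target[r][c]]
    return total / (height * width)
```

In PyTorch, `torch.nn.functional.cross_entropy(logits, target)` computes the same mean for `(N, C, H, W)` logits and `(N, H, W)` targets when class weights and `ignore_index` are left at their defaults.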

KFold Unified commands

Chennai Dataset

python train_kfold.py --data data/Chennai-Dataset-KFold/ --meta data/Chennai-Dataset-KFold/ --name ChennaiKFold --folds 5 --epochs 10 --batch 1 --loss crossentropy --model Custom_Slim_UNet > UnifiedCHNFolded.txt

Cityscape Dataset

python train_kfold.py --data data/CityScape-Dataset-KFold/ --meta data/CityScape-Dataset-KFold/ --name CityScapeKFold --folds 5 --epochs 10 --batch 1 --loss crossentropy --model Custom_Slim_UNet > UnifiedCSPFolded.txt
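`--folds 5` points to standard k-fold cross-validation. The samples are split into 5 folds, and each run holds one fold out for validation while training on the other four. How train_kfold.py orders or shuffles its samples is not shown here. The hypothetical helper below splits sample indices into contiguous folds whose sizes differ by at most one, with an optional seeded shuffle first.

```python
import random
from typing import List, Optional, Tuple


def kfold_indices(n_samples: int, k: int,
                  seed: Optional[int] = None) -> List[Tuple[List[int], List[int]]]:
    """Return (train_indices, val_indices) pairs, one per fold."""
    if not 2 <= k <= n_samples:
        raise ValueError("need 2 <= k <= n_samples")
    indices = list(range(n_samples))
    if seed is not None:
        random.Random(seed).shuffle(indices)
    # The first n_samples % k folds receive one extra sample.
    base, extra = divmod(n_samples, k)
    splits = []
    start = 0
    for fold in range(k):
        size = base + (1 if fold < extra else 0)
        val = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        splits.append((train, val))
        start += size
    return splits


if __name__ == "__main__":
    for train, val in kfold_indices(7, 3):
        print(val, len(train))  # [0, 1, 2] 4, then [3, 4] 5, then [5, 6] 5
```

The federated client_kfold.py commands take the same `--folds` flag, so the same splitting idea would apply to each client's local data.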

KFold Federated commands

Chennai Dataset

python client_kfold.py --data data/Chennai-Federated-Dataset-KFold/C1 --meta data/Chennai-Federated-Dataset-KFold --folds 5 --epochs 10 --loss crossentropy --batch 1 --model Custom_Slim_UNet --name clientKFoldCHN1 > clientKFoldCHN1.txt
python client_kfold.py --data data/Chennai-Federated-Dataset-KFold/C2 --meta data/Chennai-Federated-Dataset-KFold --folds 5 --epochs 10 --loss crossentropy --batch 1 --model Custom_Slim_UNet --name clientKFoldCHN2 > clientKFoldCHN2.txt
python client_kfold.py --data data/Chennai-Federated-Dataset-KFold/C3 --meta data/Chennai-Federated-Dataset-KFold --folds 5 --epochs 10 --loss crossentropy --batch 1 --model Custom_Slim_UNet --name clientKFoldCHN3 > clientKFoldCHN3.txt

Cityscape Dataset

python client_kfold.py --data data/CityScape-Federated-Dataset-KFold/C1 --meta data/CityScape-Federated-Dataset-KFold --folds 5 --epochs 10 --loss crossentropy --batch 1 --model Custom_Slim_UNet --name clientKFoldCSP1 > clientKFoldCSP1.txt
python client_kfold.py --data data/CityScape-Federated-Dataset-KFold/C2 --meta data/CityScape-Federated-Dataset-KFold --folds 5 --epochs 10 --loss crossentropy --batch 1 --model Custom_Slim_UNet --name clientKFoldCSP2 > clientKFoldCSP2.txt
python client_kfold.py --data data/CityScape-Federated-Dataset-KFold/C3 --meta data/CityScape-Federated-Dataset-KFold --folds 5 --epochs 10 --loss crossentropy --batch 1 --model Custom_Slim_UNet --name clientKFoldCSP3 > clientKFoldCSP3.txt

FedUKD

  1. First, download the complete Cityscapes dataset.

  2. Then download the smaller federated datasets using

    sh getAllData.sh
    
  3. Unzip the complete dataset inside the "data/" folder (the commands below expect it at "data/CityScape-BigDataSet/").

  4. Copy classes.json from the federated dataset ("data/CityScape-Dataset/") into the root of the complete dataset ("data/CityScape-BigDataSet/").

  5. Run centralized training with

    python train.py --data data/CityScape-BigDataSet/train --test data/CityScape-BigDataSet/test --meta data/CityScape-BigDataSet --num_epochs 50 --loss crossentropy --name CSBig
    
  6. In client_kd.py (line 145), replace the teacher checkpoint filename with the name of your centralized weights. The weights are saved in the "saved_models" folder, e.g.:

    teacher.load_state_dict(torch.load('./saved_models/CSBig_epoch_x_x.xxxxx.pt'))
    
  7. Start the server. There will be no console output, since the output stream is redirected to a text file. To watch the run live, remove "> server.txt" from the command and copy the output into server.txt manually once it finishes.

    python server.py > server.txt
    
  8. Start 3 clients. As with the server, there will be no console output, since each client's output is redirected to a text file.

    CLIENT 1:

    python client_kd.py --data data/CityScape-Dataset/Train/C1-Vehicle_NoPeople-65 --test data/CityScape-Dataset/Test --meta data/CityScape-Dataset --num_epochs 50 --loss crossentropy --name client1 > client1.txt
    

    CLIENT 2:

    python client_kd.py --data data/CityScape-Dataset/Train/C2-People_NoVehicle-22 --test data/CityScape-Dataset/Test --meta data/CityScape-Dataset --num_epochs 50 --loss crossentropy --name client2 > client2.txt
    

    CLIENT 3:

    python client_kd.py --data data/CityScape-Dataset/Train/C3-NoVehicle_NoPeople-11 --test data/CityScape-Dataset/Test --meta data/CityScape-Dataset --num_epochs 50 --loss crossentropy --name client3 > client3.txt
    
  9. Run plot_clients.py

    python plot_clients.py
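Step 6 requires copying a checkpoint filename by hand. The filenames follow the pattern `<name>_epoch_<x>_<x.xxxxx>.pt` (e.g. `unet_epoch_0_1.67928.pt`). If the trailing number is a loss value, the hypothetical helper below picks the lowest-loss checkpoint for a given run name. That reading of the number is an assumption; check how train.py names its files before relying on it. The helper names `pick_best` and `best_checkpoint` are illustrative only.

```python
import re
from pathlib import Path
from typing import Iterable, Optional

# Matches e.g. "CSBig_epoch_12_0.45321.pt" -> name "CSBig", epoch 12, value 0.45321.
# Treating the trailing value as a loss to minimise is an assumption.
_CKPT = re.compile(r"^(?P<name>.+)_epoch_(?P<epoch>\d+)_(?P<value>\d+(?:\.\d+)?)\.pt$")


def pick_best(filenames: Iterable[str], run_name: str) -> Optional[str]:
    """Return the filename for run_name with the smallest trailing value, or None."""
    best, best_value = None, float("inf")
    for filename in filenames:
        match = _CKPT.match(filename)
        if match is None or match.group("name") != run_name:
            continue  # not a checkpoint of this run
        value = float(match.group("value"))
        if value < best_value:
            best, best_value = filename, value
    return best


def best_checkpoint(folder: str, run_name: str) -> Optional[Path]:
    """Scan folder for *.pt files and return the best checkpoint path, or None."""
    names = [p.name for p in Path(folder).glob("*.pt")]
    best = pick_best(names, run_name)
    return Path(folder) / best if best is not None else None
```

The returned path, e.g. from `best_checkpoint("./saved_models", "CSBig")`, could then be passed to `torch.load` in place of the hard-coded filename.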