Semantic Segmentation of Satellite Images using Tensorflow 2.0
This repo details the steps taken to perform a semantic segmentation task on satellite and/or aerial images (aka tiles). A Tensorflow 2.0 deep learning model is trained on the ISPRS dataset. The dataset contains 38 patches (of 6000x6000 pixels), each consisting of a true orthophoto (TOP) extracted from a larger TOP mosaic.
Each tile is paired with a reference segmentation mask depicting the 6 classes with different colors (see below).
Development Environment
Tools and libraries:
- Python 3.5
- Imageio 2.6
- Deep Learning Libraries:
  - Low-Level API: Tensorflow 2.0 (with eager execution enabled)
  - High-Level API: Keras 2.2
  - Input pipeline API: tf.data
  - Monitoring API: TensorBoard
Infrastructure:
- 16-Core, 64GB RAM
- Nvidia 16GB GPU (Tesla P100)
- VM instance on GCP
Patch extraction and data augmentation using tf.data input pipeline
Due to the size of the tiles (6000x6000 pixels), they cannot be fed directly to the Tensorflow model, whose input size is limited to 256x256 pixels. It is therefore crucial to build an efficient and flexible input pipeline that reads the tile file, extracts smaller patches, and applies data augmentation, while being fast enough to avoid starving the model running on the GPU during training. Tensorflow's tf.data API makes it possible to build such a pipeline. The tile and its corresponding reference mask are processed in parallel, and the resulting patches look like those in the following grid.
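A minimal sketch of such a pipeline, assuming the tile and its mask are already decoded in memory (for example with imageio) and that the mask holds integer class ids rather than colors. The names `to_patches`, `augment`, `split` and `make_dataset` are hypothetical, and the specific augmentations (random flips and 90-degree rotations) are an assumption, not necessarily the ones used in this repo. Stacking the mask as an extra channel before augmenting guarantees that image and mask receive exactly the same random transformation.

```python
import tensorflow as tf

PATCH = 256  # model input size stated in this README
AUTOTUNE = tf.data.experimental.AUTOTUNE

def to_patches(image, patch=PATCH):
    """Split an (H, W, C) tensor into non-overlapping (patch, patch, C) patches.
    Border pixels that do not fill a whole patch are dropped."""
    channels = image.shape[-1]
    patches = tf.image.extract_patches(
        images=image[tf.newaxis],  # extract_patches expects a batch axis
        sizes=[1, patch, patch, 1],
        strides=[1, patch, patch, 1],
        rates=[1, 1, 1, 1],
        padding="VALID",
    )
    # (1, rows, cols, patch * patch * C) -> (rows * cols, patch, patch, C)
    return tf.reshape(patches, [-1, patch, patch, channels])

def augment(pair):
    """Random flips and 90-degree rotations. Image bands and mask are stacked
    along the channel axis, so both get exactly the same transformation."""
    pair = tf.image.random_flip_left_right(pair)
    pair = tf.image.random_flip_up_down(pair)
    pair = tf.image.rot90(pair, k=tf.random.uniform([], 0, 4, dtype=tf.int32))
    return tf.ensure_shape(pair, [PATCH, PATCH, None])  # rot90 loses static H, W

def split(pair):
    """Separate the image bands from the mask channel (stored last)."""
    return pair[..., :-1], tf.cast(pair[..., -1], tf.int32)

def make_dataset(tile, mask, batch_size=8, shuffle_buffer=1024):
    """tile: (H, W, bands) image; mask: (H, W) integer class ids."""
    pair = tf.concat([tf.cast(tile, tf.float32),
                      tf.cast(mask, tf.float32)[..., tf.newaxis]], axis=-1)
    ds = tf.data.Dataset.from_tensor_slices(to_patches(pair))
    ds = ds.shuffle(shuffle_buffer)
    ds = ds.map(augment, num_parallel_calls=AUTOTUNE)  # parallel CPU workers
    ds = ds.map(split, num_parallel_calls=AUTOTUNE)
    return ds.batch(batch_size).prefetch(AUTOTUNE)  # overlap CPU work with GPU training
```

With many tiles, the same functions could be mapped over a dataset of file paths and combined with `interleave`; that part is left out here.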
Tensorflow model architecture
The model is based on the U-Net convolutional neural network, enhanced with skip connections and residual blocks borrowed from Residual Neural Networks (ResNets), which improve the flow of the gradient during backpropagation. The model was implemented with the Keras functional API.
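A small residual U-Net can be sketched with the Keras functional API as follows. The depth, filter counts, batch normalization and 1x1 projection shortcut are assumptions chosen for illustration, and `residual_block` and `build_res_unet` are hypothetical names; the exact architecture of this repo is not reproduced here.

```python
import tensorflow as tf
from tensorflow.keras import layers

def residual_block(x, filters):
    """Two 3x3 convolutions plus a ResNet-style shortcut connection."""
    shortcut = x
    if x.shape[-1] != filters:
        # 1x1 projection so the shortcut matches the new channel count
        shortcut = layers.Conv2D(filters, 1, padding="same")(x)
    y = layers.Conv2D(filters, 3, padding="same")(x)
    y = layers.BatchNormalization()(y)
    y = layers.Activation("relu")(y)
    y = layers.Conv2D(filters, 3, padding="same")(y)
    y = layers.BatchNormalization()(y)
    y = layers.Add()([shortcut, y])  # residual addition eases gradient flow
    return layers.Activation("relu")(y)

def build_res_unet(input_shape=(256, 256, 3), num_classes=6, base=32, depth=3):
    inputs = tf.keras.Input(shape=input_shape)
    x, skips = inputs, []
    for d in range(depth):  # encoder: residual block, then downsample
        x = residual_block(x, base * 2 ** d)
        skips.append(x)
        x = layers.MaxPooling2D(2)(x)
    x = residual_block(x, base * 2 ** depth)  # bottleneck
    for d in reversed(range(depth)):  # decoder: upsample, concat U-Net skip
        x = layers.Conv2DTranspose(base * 2 ** d, 2, strides=2, padding="same")(x)
        x = layers.Concatenate()([x, skips[d]])
        x = residual_block(x, base * 2 ** d)
    outputs = layers.Conv2D(num_classes, 1, activation="softmax")(x)  # per-pixel class probabilities
    return tf.keras.Model(inputs, outputs, name="res_unet")
```

The two kinds of skip connection play different roles: the concatenations carry spatial detail from encoder to decoder (U-Net), while the additions inside each block shorten the gradient path (ResNet).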
Model training
We experimented with several loss functions based on recent AI literature:
- The classical Sparse Categorical Cross Entropy.
- The categorical focal loss, which is helpful when the target classes are imbalanced.
- A custom Jaccard distance loss function, since we aim to maximize the Intersection over Union (IoU).
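The two non-standard losses above can be sketched as Keras-compatible loss functions. This is one common formulation of each, assuming sparse integer labels and softmax outputs; the `gamma`, `alpha` and `smooth` defaults are assumptions, and the repo's own "custom-made" Jaccard loss may differ in detail.

```python
import tensorflow as tf

NUM_CLASSES = 6

def _one_hot(y_true, y_pred, num_classes):
    # Sparse labels (B, H, W) or (B, H, W, 1) -> one-hot (B, H, W, C)
    labels = tf.reshape(tf.cast(y_true, tf.int32), tf.shape(y_pred)[:-1])
    return tf.one_hot(labels, num_classes)

def categorical_focal_loss(gamma=2.0, alpha=0.25, num_classes=NUM_CLASSES):
    def loss(y_true, y_pred):
        y_true_oh = _one_hot(y_true, y_pred, num_classes)
        y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)  # avoid log(0)
        ce = -y_true_oh * tf.math.log(y_pred)
        # (1 - p)^gamma down-weights pixels that are already well classified
        weight = alpha * tf.pow(1.0 - y_pred, gamma)
        return tf.reduce_sum(weight * ce, axis=-1)  # per-pixel loss
    return loss

def jaccard_distance_loss(smooth=1.0, num_classes=NUM_CLASSES):
    def loss(y_true, y_pred):
        y_true_oh = _one_hot(y_true, y_pred, num_classes)
        # soft intersection and union per image and per class: shape (B, C)
        intersection = tf.reduce_sum(y_true_oh * y_pred, axis=[1, 2])
        union = tf.reduce_sum(y_true_oh + y_pred, axis=[1, 2]) - intersection
        iou = (intersection + smooth) / (union + smooth)
        return 1.0 - tf.reduce_mean(iou, axis=-1)  # per-image loss
    return loss
```

Either factory can be passed to `model.compile(loss=...)`, for example `loss=jaccard_distance_loss()`. Using soft probabilities instead of arg-max predictions keeps the IoU surrogate differentiable.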
In addition, we used the learning rate finder to identify the optimal learning rate for the chosen loss function. The search produces the following loss curve, which shows the learning rate sweet spot that should be picked (right before the global minimum) for optimal training.
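A learning rate finder can be sketched as a Keras callback that raises the learning rate exponentially from batch to batch and records a smoothed loss. `LRFinder` and `suggest_lr` are hypothetical names; the smoothing factor, the "stop when the loss exceeds 4x its best value" rule, and the "one tenth of the minimum-loss learning rate" heuristic are assumptions, not necessarily what this repo does.

```python
import math
import tensorflow as tf

class LRFinder(tf.keras.callbacks.Callback):
    """Sweep the learning rate exponentially from min_lr to max_lr."""

    def __init__(self, min_lr=1e-6, max_lr=1.0, num_steps=100, beta=0.98):
        super().__init__()
        self.min_lr, self.max_lr = min_lr, max_lr
        self.num_steps, self.beta = num_steps, beta
        self.lrs, self.losses = [], []
        self._step, self._avg, self._best = 0, 0.0, math.inf

    def _lr(self):
        frac = self._step / max(1, self.num_steps - 1)
        return self.min_lr * (self.max_lr / self.min_lr) ** frac

    def on_train_batch_begin(self, batch, logs=None):
        self.model.reset_metrics()  # so logs["loss"] reflects this batch only
        self.model.optimizer.learning_rate = self._lr()

    def on_train_batch_end(self, batch, logs=None):
        loss = float(logs["loss"])
        self._avg = self.beta * self._avg + (1 - self.beta) * loss
        smoothed = self._avg / (1 - self.beta ** (self._step + 1))  # bias correction
        self.lrs.append(self._lr())
        self.losses.append(smoothed)
        self._best = min(self._best, smoothed)
        self._step += 1
        if self._step >= self.num_steps or smoothed > 4 * self._best:
            self.model.stop_training = True  # sweep finished or loss diverged

def suggest_lr(finder):
    """Heuristic (assumption): one tenth of the learning rate at the loss minimum."""
    best = min(range(len(finder.losses)), key=finder.losses.__getitem__)
    return finder.lrs[best] / 10.0
```

The recorded `finder.lrs` and `finder.losses` are what would be plotted on a log-scaled learning rate axis to produce a curve like the one described above.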
Once the optimal learning rate is found, training is performed using the one-cycle policy. The curves below show how the learning rate and the SGD momentum evolve during training.
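The one-cycle policy can be expressed as a pure function of the training step. This sketch uses linear ramps, with the learning rate rising while the momentum falls, and the reverse in the second phase. The parameter names and defaults (`div_factor`, `final_div`, `pct_start`, momentum bounds 0.95/0.85) are assumptions; cosine ramps are another common choice.

```python
def one_cycle(step, total_steps, max_lr, div_factor=25.0, final_div=1e4,
              pct_start=0.3, mom_max=0.95, mom_min=0.85):
    """Return (learning_rate, momentum) at `step` of a one-cycle schedule."""
    warmup = max(1, int(total_steps * pct_start))
    start_lr = max_lr / div_factor
    end_lr = start_lr / final_div
    if step < warmup:
        # phase 1: learning rate up, momentum down
        t = step / warmup
        lr = (1 - t) * start_lr + t * max_lr
        mom = (1 - t) * mom_max + t * mom_min
    else:
        # phase 2: learning rate annealed far below start_lr, momentum back up
        t = min(1.0, (step - warmup) / max(1, total_steps - warmup))
        lr = (1 - t) * max_lr + t * end_lr
        mom = (1 - t) * mom_min + t * mom_max
    return lr, mom
```

Here `max_lr` would be the value picked with the learning rate finder. How the returned values are pushed into the SGD optimizer (for example from a callback) depends on the Tensorflow/Keras version and is not shown.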
During training, we monitor the performance metrics (accuracy, IoU, and the loss), as shown below. Training is halted by early stopping once the performance metrics stagnate.
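A possible monitoring setup is sketched below. In the tf.keras versions we are aware of, `tf.keras.metrics.MeanIoU` expects predicted class ids, so the hypothetical `SparseMeanIoU` wrapper takes the arg-max of the softmax output first. The monitored quantity, `patience=5` and `restore_best_weights=True` are assumptions.

```python
import tensorflow as tf

class SparseMeanIoU(tf.keras.metrics.MeanIoU):
    """MeanIoU on softmax outputs: convert probabilities to class ids first."""

    def update_state(self, y_true, y_pred, sample_weight=None):
        return super().update_state(y_true, tf.argmax(y_pred, axis=-1), sample_weight)

def training_callbacks(log_dir="logs"):
    return [
        # stop once the validation loss has not improved for 5 epochs
        tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=5,
                                         restore_best_weights=True),
        # write loss and metric curves for TensorBoard
        tf.keras.callbacks.TensorBoard(log_dir=log_dir),
    ]

# Example wiring (model and datasets not shown):
# model.compile(optimizer=tf.keras.optimizers.SGD(momentum=0.9),
#               loss="sparse_categorical_crossentropy",
#               metrics=["accuracy", SparseMeanIoU(num_classes=6, name="iou")])
# model.fit(train_ds, validation_data=val_ds, epochs=100,
#           callbacks=training_callbacks())
```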
Best Model Performance Metrics
The model performs strongly on the validation dataset, especially on the Building, Road, and Car classes (IoU > 0.8). Below are the confusion matrix and the per-class IoU metrics, along with some reference visuals for the IoU metric.
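The per-class IoU follows directly from the confusion matrix: for class c, IoU = TP / (TP + FP + FN), where TP is the diagonal entry, FN is the rest of row c and FP is the rest of column c. A minimal NumPy sketch (the function names are hypothetical):

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_classes):
    """Rows = reference class, columns = predicted class."""
    idx = num_classes * y_true.ravel().astype(np.int64) + y_pred.ravel().astype(np.int64)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)

def per_class_iou(cm):
    """IoU_c = TP / (TP + FP + FN); NaN for classes absent from both maps."""
    tp = np.diag(cm).astype(np.float64)
    union = cm.sum(axis=0) + cm.sum(axis=1) - tp  # predicted + reference - overlap
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(union > 0, tp / union, np.nan)
```

For instance, reference labels `[0, 1, 1, 1]` against predictions `[0, 1, 0, 1]` give the matrix `[[1, 0], [1, 2]]` and per-class IoUs of 0.5 and 2/3.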
Tile Prediction using Test-Time Augmentation
Applying the inference pipeline to a new tile of the same size (6000x6000) can be slow if we loop through the tile to extract the patches, make a batch prediction, and stitch the patches back together to reconstruct the tile. Fortunately, we can perform such inference without any loop, thanks to a tile reconstruction trick based on Tensorflow's tf.scatter_nd. This reduces the inference time per tile from minutes to seconds.
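One way to implement such a loop-free reconstruction is sketched below; it is not necessarily the trick used in this repo. It assumes the patches were taken on a regular grid with a fixed stride, in row-major order (the order `tf.image.extract_patches` produces). Because `tf.scatter_nd` sums values written to the same index, scattering a tensor of ones alongside the predictions yields per-pixel counts, so overlapping patches are averaged. `stitch_patches` is a hypothetical name.

```python
import tensorflow as tf

def stitch_patches(patches, height, width, stride):
    """Reassemble (N, p, p, C) patch predictions into a (height, width, C) map.
    Pixels covered by several patches are averaged; uncovered pixels stay 0."""
    p, channels = patches.shape[1], patches.shape[-1]
    rows = tf.range(0, height - p + 1, stride)
    cols = tf.range(0, width - p + 1, stride)
    # top-left corner of every patch, row-major: (N, 2)
    r0, c0 = tf.meshgrid(rows, cols, indexing="ij")
    origins = tf.stack([tf.reshape(r0, [-1]), tf.reshape(c0, [-1])], axis=-1)
    # (row, col) offset of every pixel inside a patch: (p, p, 2)
    di, dj = tf.meshgrid(tf.range(p), tf.range(p), indexing="ij")
    offsets = tf.stack([di, dj], axis=-1)
    # absolute tile coordinates of every patch pixel: (N, p, p, 2)
    indices = origins[:, None, None, :] + offsets[None]
    summed = tf.scatter_nd(indices, patches, [height, width, channels])  # sums duplicates
    counts = tf.scatter_nd(indices, tf.ones_like(patches[..., :1]), [height, width, 1])
    return summed / tf.maximum(counts, 1.0)
```

Applied to softmax outputs, the averaged map can then be turned into class ids with `tf.argmax(..., axis=-1)`.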
In addition, applying the test-time augmentation technique when running inference on a tile improves the prediction quality by several points, as shown below:
- Without test-time augmentation
- With test-time augmentation
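Test-time augmentation can be sketched as follows: predict on several transformed copies of each batch of patches, undo each transformation on the predictions, and average. This version assumes square patches and uses the 8 flip/rotation combinations; which transformations this repo actually averages over is not stated. `predict_with_tta` is a hypothetical name, and `predict_fn` stands for any callable mapping a (N, p, p, bands) batch to (N, p, p, num_classes) probabilities, such as a trained Keras model.

```python
import tensorflow as tf

def predict_with_tta(predict_fn, patches):
    """Average predictions over the 8 flips/rotations of each (square) patch."""
    outputs = []
    for flip in (False, True):
        x = tf.image.flip_left_right(patches) if flip else patches
        for k in range(4):
            y = predict_fn(tf.image.rot90(x, k=k))  # forward: flip, then rotate
            y = tf.image.rot90(y, k=(4 - k) % 4)    # inverse: undo rotation first
            if flip:
                y = tf.image.flip_left_right(y)     # then undo the flip
            outputs.append(y)
    return tf.reduce_mean(tf.stack(outputs), axis=0)
```

The averaged patch predictions could then be passed to a stitching step such as the one above before taking the per-pixel arg-max. The cost is 8 forward passes per patch.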