
Semantic Segmentation of Satellite Images using Tensorflow 2.0

This repo details the steps carried out to perform a semantic segmentation task on satellite and/or aerial images (aka tiles). A Tensorflow 2.0 deep learning model is trained using the ISPRS dataset. The dataset contains 38 patches (of 6000x6000 pixels), each consisting of a true orthophoto (TOP) extracted from a larger TOP mosaic.

Each tile is paired with a reference segmentation mask depicting the 6 classes with different colors (see below).

Development Environment

Tools and libraries:

Infrastructure:

Patch extraction and data augmentation using tf.data input pipeline

Due to the size of the tiles (6000x6000 pixels), it is not possible to feed them directly to the Tensorflow model, whose image input size is limited to 256x256 pixels. Thus it is crucial to build an efficient and flexible input pipeline that reads the tile file, extracts smaller patches, and performs data augmentation, while being fast enough to avoid data starvation of the model sitting on the GPU during the training phase. Fortunately, Tensorflow's tf.data API allows the building of such a pipeline. The tile and its corresponding reference mask are processed in parallel, and the resulting smaller patches look like the following grid:
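A minimal sketch of such a pipeline (the 1024x1024 stand-in tile, the flip-only augmentation, and the batch size are illustrative assumptions, not the repo's exact settings):

```python
import tensorflow as tf

PATCH = 256  # model input size

def extract_patches(tile, mask):
    """Split a tile and its mask into non-overlapping PATCH x PATCH patches."""
    def split(img):
        ch = img.shape[-1]
        p = tf.image.extract_patches(
            images=img[tf.newaxis],
            sizes=[1, PATCH, PATCH, 1],
            strides=[1, PATCH, PATCH, 1],
            rates=[1, 1, 1, 1],
            padding="VALID")
        return tf.reshape(p, [-1, PATCH, PATCH, ch])
    return tf.data.Dataset.from_tensor_slices((split(tile), split(mask)))

def augment(patch, mask):
    """Apply the same random flip to the patch and its mask."""
    seed = tf.random.uniform([2], maxval=2**30, dtype=tf.int32)
    patch = tf.image.stateless_random_flip_left_right(patch, seed)
    mask = tf.image.stateless_random_flip_left_right(mask, seed)
    return patch, mask

tile = tf.zeros([1024, 1024, 3])  # stand-in for a 6000x6000 RGB tile
mask = tf.zeros([1024, 1024, 1])
ds = (extract_patches(tile, mask)
      .shuffle(64)
      .map(augment, num_parallel_calls=tf.data.AUTOTUNE)
      .batch(8)
      .prefetch(tf.data.AUTOTUNE))
```

The stateless flip with a shared seed keeps the image and mask augmentations in lockstep; `prefetch(AUTOTUNE)` overlaps preprocessing with GPU compute to avoid starving the model.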

Tensorflow model architecture

The model is based on the U-Net convolutional neural network, enhanced with skip connections and with residual blocks borrowed from Residual Neural Networks (ResNets), which improve the flow of the gradient during the backpropagation step. The Keras functional API was used to implement the model.
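The idea can be sketched with the functional API as below; the filter counts, depth, and block layout are simplified assumptions for illustration, not the repo's exact architecture:

```python
import tensorflow as tf
from tensorflow.keras import layers

def residual_block(x, filters):
    """Two 3x3 convolutions with a shortcut connection (ResNet-style)."""
    shortcut = x
    if x.shape[-1] != filters:  # match channel count for the addition
        shortcut = layers.Conv2D(filters, 1, padding="same")(x)
    y = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
    y = layers.Conv2D(filters, 3, padding="same")(y)
    y = layers.Add()([shortcut, y])
    return layers.Activation("relu")(y)

# Minimal encoder/decoder with one U-Net skip connection
inputs = layers.Input(shape=(256, 256, 3))
e1 = residual_block(inputs, 32)
p1 = layers.MaxPooling2D()(e1)
b = residual_block(p1, 64)
u1 = layers.UpSampling2D()(b)
d1 = residual_block(layers.Concatenate()([u1, e1]), 32)  # skip connection
outputs = layers.Conv2D(6, 1, activation="softmax")(d1)  # 6 classes
model = tf.keras.Model(inputs, outputs)
```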

Model training

We experimented with several loss functions based on recent AI literature.
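One loss commonly used in the segmentation literature is the soft Dice loss; as a sketch (not necessarily one of the exact losses tried here):

```python
import tensorflow as tf

def soft_dice_loss(y_true, y_pred, eps=1e-6):
    """Soft Dice loss averaged over classes.
    y_true is one-hot, y_pred is a softmax output, both (batch, H, W, classes)."""
    axes = (1, 2)  # sum over the spatial dimensions
    intersection = tf.reduce_sum(y_true * y_pred, axis=axes)
    denom = tf.reduce_sum(y_true + y_pred, axis=axes)
    dice = (2.0 * intersection + eps) / (denom + eps)
    return 1.0 - tf.reduce_mean(dice)
```

Unlike plain cross-entropy, Dice directly optimizes an overlap measure, which helps with class imbalance (e.g. the small Car class).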

In addition, we adopted the learning rate finder to spot the optimum learning rate for a chosen loss function. The finding process produces the following loss curve, showing the learning-rate sweet spot that should be picked (right before the global minimum) for optimum training.
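The core of the finder is an exponential sweep of the learning rate, training one mini-batch per value while recording the loss; a sketch of the schedule (the bounds and step count are illustrative):

```python
import numpy as np

def lr_sweep(min_lr=1e-6, max_lr=1.0, n_steps=100):
    """Exponentially spaced learning rates for an LR-range test:
    train one mini-batch per value and record the loss."""
    return min_lr * (max_lr / min_lr) ** (np.arange(n_steps) / (n_steps - 1))

lrs = lr_sweep()
# Plot loss vs. lrs (log scale) and pick a rate slightly before the
# curve's minimum, where the loss is still falling steeply.
```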

Once the optimum learning rate is found, the training is performed using the one-cycle policy training strategy. The curves below depict the evolution of the learning rate and the SGD momentum during training.
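A linear variant of the one-cycle schedule can be sketched as follows; the warm-up fraction, divisor, and momentum bounds are illustrative defaults, not necessarily those used in training:

```python
def one_cycle(step, total_steps, max_lr=1e-3, div=25.0,
              mom_max=0.95, mom_min=0.85, warmup_frac=0.3):
    """Linear warm-up then linear decay for the learning rate, with the
    SGD momentum moving in the opposite direction (high -> low -> high)."""
    warmup = int(total_steps * warmup_frac)
    min_lr = max_lr / div
    if step < warmup:
        t = step / warmup
        lr = min_lr + t * (max_lr - min_lr)
        mom = mom_max - t * (mom_max - mom_min)
    else:
        t = (step - warmup) / (total_steps - warmup)
        lr = max_lr - t * (max_lr - min_lr)
        mom = mom_min + t * (mom_max - mom_min)
    return lr, mom
```

The inverse coupling of learning rate and momentum is the hallmark of the policy: high momentum stabilizes the small-LR phases, low momentum lets the large-LR phase explore.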

Naturally, during training we monitor the performance metrics (accuracy, IoU, and the loss function) as shown below. The training is halted by the Early Stopping strategy once the performance metrics stagnate.
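With Keras this can be wired up as a callback; the monitored metric name, patience, and compile settings below are illustrative assumptions:

```python
import tensorflow as tf

# Stop when the validation mean IoU stops improving; keep the best weights.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_mean_io_u",  # default name of tf.keras.metrics.MeanIoU
    mode="max",
    patience=10,
    restore_best_weights=True)

# model.compile(optimizer="sgd",
#               loss="categorical_crossentropy",
#               metrics=["accuracy", tf.keras.metrics.MeanIoU(num_classes=6)])
# model.fit(train_ds, validation_data=val_ds, epochs=200,
#           callbacks=[early_stop])
```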

Best Model Performance Metrics

The model performance measured on the validation dataset is strong, especially on the Building, Road, and Car classes (IoU > 0.8). Below are the confusion matrix and the per-class IoU metrics, along with some reference visuals for the IoU metric.
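Per-class IoU falls directly out of the confusion matrix; a sketch with a hypothetical 3-class matrix (the numbers are made up for illustration):

```python
import numpy as np

def per_class_iou(cm):
    """Per-class IoU from a confusion matrix (rows = truth, cols = prediction):
    IoU_c = TP_c / (TP_c + FP_c + FN_c)."""
    tp = np.diag(cm).astype(float)
    fp = cm.sum(axis=0) - tp  # predicted as c but actually another class
    fn = cm.sum(axis=1) - tp  # actually c but predicted as another class
    return tp / (tp + fp + fn)

cm = np.array([[50,  2,  0],
               [ 3, 40,  1],
               [ 0,  1, 30]])
ious = per_class_iou(cm)
```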

Tile Prediction using Test-Time Augmentation

Applying the inference pipeline to a new tile of the same size (6000x6000) could be slow if we loop through the tile to extract the patches, make a batch prediction, and stitch the patches back together to reconstruct the tile. Fortunately, we can perform such inference without any loop thanks to a clever tile-reconstruction trick using Tensorflow's tf.scatter_nd. Inference time on a tile is reduced from minutes to seconds.
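A sketch of the scatter-based reconstruction for non-overlapping patches (the index bookkeeping below is one possible way to implement the trick, not necessarily the repo's exact code):

```python
import tensorflow as tf

def stitch_patches(patches, tile_size, patch_size):
    """Reassemble non-overlapping patches (in row-major grid order)
    into a full tile with a single tf.scatter_nd call, no Python loop."""
    n = tile_size // patch_size  # patches per row/column
    ys, xs = tf.meshgrid(tf.range(tile_size), tf.range(tile_size),
                         indexing="ij")
    idx = tf.stack([ys, xs], axis=-1)  # (tile, tile, 2): idx[y, x] = [y, x]
    # Regroup pixel coordinates into the same order as the patch tensor.
    idx = tf.reshape(idx, [n, patch_size, n, patch_size, 2])
    idx = tf.transpose(idx, [0, 2, 1, 3, 4])
    idx = tf.reshape(idx, [-1, patch_size, patch_size, 2])
    ch = patches.shape[-1]
    return tf.scatter_nd(idx, patches, [tile_size, tile_size, ch])
```

Each pixel of each patch carries its destination (y, x) coordinate, so one `scatter_nd` writes the whole tile at once.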

In addition, once inference is performed on a tile, the test-time augmentation (TTA) technique improves the prediction quality by several points, as shown below:
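A minimal TTA sketch, assuming flip augmentations only (the actual set of transforms used may differ): predict on transformed copies of the batch, undo each transform on the prediction, and average.

```python
import tensorflow as tf

def predict_tta(model, batch):
    """Average class probabilities over flip augmentations, undoing each
    flip on the prediction before averaging."""
    transforms = [
        (lambda x: x, lambda x: x),  # identity
        (tf.image.flip_left_right, tf.image.flip_left_right),
        (tf.image.flip_up_down, tf.image.flip_up_down),
    ]
    preds = [undo(model(fwd(batch))) for fwd, undo in transforms]
    return tf.add_n(preds) / len(preds)
```

Flips are self-inverse, so the same function both applies and undoes each transform; averaging the aligned probability maps smooths out orientation-dependent errors.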

Implementation and Report