Cross-stitch-Networks-for-Multi-task-Learning

This project is a TensorFlow implementation of the multi-task learning method described in the paper Cross-stitch Networks for Multi-task Learning.

Arguments

Dataset

Fashion-MNIST

Fashion-MNIST is a dataset of Zalando's article images, consisting of a training set of 60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image, associated with a label from 10 classes:

| Label | Description | Label | Description |
|-------|-------------|-------|-------------|
| 0 | T-shirt/top | 5 | Sandal |
| 1 | Trouser | 6 | Shirt |
| 2 | Pullover | 7 | Sneaker |
| 3 | Dress | 8 | Bag |
| 4 | Coat | 9 | Ankle boot |
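
For reference, Fashion-MNIST ships with tf.keras; a minimal loading sketch (not the repository's own input pipeline) could look like this:

```python
import tensorflow as tf

# Load Fashion-MNIST: 60,000 training and 10,000 test images, 28x28 grayscale.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()

# Add a channel dimension and scale pixel values to [0, 1].
x_train = x_train[..., None] / 255.0
x_test = x_test[..., None] / 255.0

print(x_train.shape, y_train.shape)  # (60000, 28, 28, 1) (60000,)
```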

For multi-task learning, I created a second label for each image, derived from the original labels:

| Label | Original Labels | Description |
|-------|-----------------|-------------|
| 0 | 5, 7, 9 | Shoes |
| 1 | 3, 6, 8 | For Women |
| 2 | 0, 1, 2, 4 | Other |

The network will train these two classifiers together.
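
A hypothetical helper (names are mine, not the repository's) that derives the second task's label from the original one according to this mapping:

```python
import numpy as np

# Map original Fashion-MNIST labels to the second task's labels.
LABEL_MAP = {5: 0, 7: 0, 9: 0,        # shoes
             3: 1, 6: 1, 8: 1,        # "for women"
             0: 2, 1: 2, 2: 2, 4: 2}  # other

def second_task_labels(original_labels):
    return np.vectorize(LABEL_MAP.get)(original_labels)

print(second_task_labels(np.array([9, 3, 0])))  # [0 1 2]
```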

Network

Without task sharing

As a baseline, a network without cross-stitch units is built, which simply places two convolutional neural networks side by side. Each network handles one task, and their parameters are not shared. The final loss is the sum of the two sub-networks' loss functions.

Here is an overview of this structure:

Network structure without task sharing

Both convolutional sub-networks have the same architecture:

| Layer | Output size | Filter size / stride |
|--------|------------|----------------------|
| conv1 | 28x28x32 | 3x3 / 1 |
| pool_1 | 14x14x32 | 2x2 / 2 |
| conv2 | 14x14x64 | 3x3 / 1 |
| pool_2 | 7x7x64 | 2x2 / 2 |
| fc_3 | 1024 | |
| output | 10 or 3, depending on task | |
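
A minimal tf.keras sketch of one such sub-network, following the table above (illustrative only; the repository may build its graph differently, and the layer names simply mirror the table):

```python
import tensorflow as tf
from tensorflow.keras import layers

def build_subnetwork(num_classes):
    """One sub-network following the table above; num_classes is 10 or 3."""
    inputs = tf.keras.Input(shape=(28, 28, 1))
    x = layers.Conv2D(32, 3, padding="same", activation="relu", name="conv1")(inputs)  # 28x28x32
    x = layers.MaxPooling2D(2, name="pool_1")(x)                                       # 14x14x32
    x = layers.Conv2D(64, 3, padding="same", activation="relu", name="conv2")(x)       # 14x14x64
    x = layers.MaxPooling2D(2, name="pool_2")(x)                                       # 7x7x64
    x = layers.Flatten()(x)
    x = layers.Dense(1024, activation="relu", name="fc_3")(x)
    outputs = layers.Dense(num_classes, name="output")(x)                              # logits
    return tf.keras.Model(inputs, outputs)

# Baseline: two independent sub-networks; the total loss is simply the sum
# of the two per-task cross-entropy losses.
net_10 = build_subnetwork(10)  # original 10-class task
net_3 = build_subnetwork(3)    # derived 3-class task
```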

With Cross Stitch

A cross-stitch unit is a transformation applied between layers: it models the relationship between the two tasks as a learnable linear combination of their activations,

x̃_A = α_AA · x_A + α_AB · x_B,  x̃_B = α_BA · x_A + α_BB · x_B.

The network learns this relationship on its own; compared with manually tuning a shared network structure, this end-to-end approach works better.
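
One way to express such a unit is as a small Keras layer holding a trainable 2x2 matrix of α values; this is a sketch under that assumption, not the repository's actual implementation:

```python
import tensorflow as tf

class CrossStitch(tf.keras.layers.Layer):
    """Cross-stitch unit for two tasks (illustrative sketch).

    Given activations x_a and x_b of the same shape, it returns
        x_a' = α_aa * x_a + α_ab * x_b
        x_b' = α_ba * x_a + α_bb * x_b
    where the 2x2 matrix of α values is learned with the rest of the network.
    """

    def build(self, input_shape):
        # Start close to the identity so each task initially keeps mostly
        # its own activations.
        self.alpha = self.add_weight(
            name="alpha",
            shape=(2, 2),
            initializer=lambda shape, dtype=None: tf.constant(
                [[0.9, 0.1], [0.1, 0.9]], dtype=dtype),
            trainable=True,
        )

    def call(self, inputs):
        x_a, x_b = inputs
        out_a = self.alpha[0, 0] * x_a + self.alpha[0, 1] * x_b
        out_b = self.alpha[1, 0] * x_a + self.alpha[1, 1] * x_b
        return out_a, out_b

# Usage between the two sub-networks, e.g. after a pooling layer:
# pool_a, pool_b = CrossStitch()([pool_a, pool_b])
```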

Here is an overview of this structure:

Network structure with Cross Stitch

The convolutional sub-networks have the same architecture as above. As suggested in the paper, cross-stitch units are added only after pooling layers and fully connected layers.

Training

Evaluation

The overall accuracy is calculated by averaging the accuracies of all sub-tasks.
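
For instance (hypothetical helper, not the repository's evaluation code):

```python
import numpy as np

# Overall metric: the mean of the per-task test accuracies.
def overall_accuracy(per_task_accuracies):
    return float(np.mean(per_task_accuracies))

# e.g. overall_accuracy([0.91, 0.97]) -> 0.94
```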

With the cross-stitch transformation, test accuracy improves by more than 1%.

Orange: without task sharing. Blue: with cross-stitch. (Plots: test accuracy and total loss.)

Result

For Fashion-MNIST the new labels are derived from the original labels, so the two classification tasks are highly related. I also used this technique to build a gender-age classifier on the VGGFace2 dataset, whose labels are more independent. In both experiments cross-stitch improves the accuracy. Although this project trains only two tasks, it can easily be extended to more.

I did not pretrain the sub-networks as suggested in the paper, and I used a different initialization strategy. A better result might be achieved with more tuning.