Auto-Lambda

This repository contains the source code of Auto-Lambda and baselines from the paper, Auto-Lambda: Disentangling Dynamic Task Relationships.

We encourage readers to check out our project page, which includes further discussion and insights not covered in our technical paper.

Multi-task Methods

We implemented all weighting and gradient-based baselines presented in the paper for computer vision tasks: Dense Prediction Tasks (for NYUv2 and CityScapes) and Multi-domain Classification Tasks (for CIFAR-100).

Specifically, we have implemented the following multi-task optimisation methods:

Weighting-based: Equal, DWA, Uncertainty, Auto-Lambda

Gradient-based: GradDrop, PCGrad, CAGrad

Note: Applying a combination of both weighting and gradient-based methods can further improve performance.
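To illustrate how a weighting-based and a gradient-based method can be stacked, the minimal NumPy sketch below scales each task's gradient by a task weight and then combines the weighted gradients with PCGrad's projection rule (a task gradient is projected onto the normal plane of any other task gradient it conflicts with). It works on flattened gradient vectors rather than a PyTorch model, and applying the weights before the projection is an assumption about the composition order, not necessarily how this repository does it.

```python
import numpy as np


def pcgrad(task_grads, weights=None, seed=0):
    """Combine per-task gradients with PCGrad-style projection (illustrative sketch).

    task_grads: list of 1-D arrays, one flattened gradient per task.
    weights: optional per-task weights from a weighting method (e.g. equal, DWA);
             applying them before the projection is an assumption.
    """
    grads = [np.asarray(g, dtype=float) for g in task_grads]
    if weights is not None:
        grads = [w * g for w, g in zip(weights, grads)]
    rng = np.random.default_rng(seed)
    projected = []
    for i, g_i in enumerate(grads):
        g = g_i.copy()
        others = [j for j in range(len(grads)) if j != i]
        rng.shuffle(others)  # PCGrad visits the other tasks in random order
        for j in others:
            g_j = grads[j]
            dot = g @ g_j
            if dot < 0:  # conflicting directions: remove the component along g_j
                g -= dot / (g_j @ g_j) * g_j
        projected.append(g)
    # The update direction is the sum of the projected task gradients.
    return np.sum(projected, axis=0)


g_seg = np.array([1.0, 0.0])
g_depth = np.array([-1.0, 1.0])
print(pcgrad([g_seg, g_depth]))  # → [0.5 1.5]
```

Without the projection, the two gradients would simply sum to `[0. 1.]`, cancelling the first task's component entirely.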

Datasets

We applied the same data pre-processing as in our previous project, MTAN, which experimented on NYUv2 and CityScapes.

Note: We have included a new task, Part Segmentation, for the CityScapes dataset. Please install the panoptic_parts package (pip install panoptic_parts) for CityScapes experiments. The pre-processing file for CityScapes is also included in the dataset folder.

Experiments

All experiments were written in PyTorch 1.7 and can be trained with different flags (hyper-parameters) when running each training script. We briefly introduce some important flags below.

| Flag Name | Usage | Comments |
| --- | --- | --- |
| network | choose multi-task network: split, mtan | both architectures are based on ResNet-50; only available in dense prediction tasks |
| dataset | choose dataset: nyuv2, cityscapes | only available in dense prediction tasks |
| weight | choose weighting-based method: equal, uncert, dwa, autol | only autol behaves differently when set to different primary tasks |
| grad_method | choose gradient-based method: graddrop, pcgrad, cagrad | weight and grad_method can be applied together |
| task | choose primary tasks: seg, depth, normal for NYUv2; seg, part_seg, disp for CityScapes; all for a combination of all 3 standard tasks | only available in dense prediction tasks |
| with_noise | toggle on to add a noise prediction task for training (to evaluate robustness in the auxiliary learning setting) | only available in dense prediction tasks |
| subset_id | choose domain ID for CIFAR-100; choose -1 for the multi-task learning setting | only available in CIFAR-100 tasks |
| autol_init | initialisation of Auto-Lambda, default 0.1 | only available when applying Auto-Lambda |
| autol_lr | learning rate of Auto-Lambda, default 1e-4 for NYUv2 and 3e-5 for CityScapes | only available when applying Auto-Lambda |
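The dense-prediction flags in the table can be summarised as a command-line parser. The sketch below is hypothetical and is not the repository's own argument parser: the flag names, choices, autol_init and the per-dataset autol_lr defaults come from the table, while every other default, the extra `none` choice for grad_method, and the per-dataset task check are assumptions made for illustration. `subset_id` is omitted because it only applies to the CIFAR-100 scripts.

```python
import argparse

# Per-dataset Auto-Lambda learning rates, as listed in the flag table.
AUTOL_LR_DEFAULTS = {"nyuv2": 1e-4, "cityscapes": 3e-5}

# Primary tasks available per dataset, as listed in the flag table.
VALID_TASKS = {
    "nyuv2": {"seg", "depth", "normal", "all"},
    "cityscapes": {"seg", "part_seg", "disp", "all"},
}


def build_parser():
    # Hypothetical parser; defaults other than autol_init are assumptions.
    parser = argparse.ArgumentParser(description="Sketch of dense-prediction training flags")
    parser.add_argument("--network", choices=["split", "mtan"], default="split")
    parser.add_argument("--dataset", choices=["nyuv2", "cityscapes"], default="nyuv2")
    parser.add_argument("--weight", choices=["equal", "uncert", "dwa", "autol"], default="equal")
    parser.add_argument("--grad_method", choices=["none", "graddrop", "pcgrad", "cagrad"], default="none")
    parser.add_argument("--task", choices=["seg", "depth", "normal", "part_seg", "disp", "all"], default="all")
    parser.add_argument("--with_noise", action="store_true")
    parser.add_argument("--autol_init", type=float, default=0.1)
    parser.add_argument("--autol_lr", type=float, default=None)  # resolved per dataset in parse()
    parser.add_argument("--gpu", type=int, default=0)
    return parser


def parse(argv):
    args = build_parser().parse_args(argv)
    if args.task not in VALID_TASKS[args.dataset]:
        raise ValueError(f"task {args.task!r} is not available for {args.dataset}")
    if args.autol_lr is None:
        args.autol_lr = AUTOL_LR_DEFAULTS[args.dataset]
    return args


args = parse(["--dataset", "cityscapes", "--task", "part_seg", "--weight", "autol", "--gpu", "0"])
print(args.autol_lr)  # → 3e-05
```

Because weight and grad_method are independent flags, a combination such as `--weight autol --grad_method cagrad` parses without conflict.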

Training Auto-Lambda in Multi-task / Auxiliary Learning Mode:

python trainer_dense.py --dataset [nyuv2, cityscapes] --task [PRIMARY_TASK] --weight autol --gpu 0   # for NYUv2 or CityScapes dataset
python trainer_cifar.py --subset_id [PRIMARY_DOMAIN_ID] --weight autol --gpu 0   # for CIFAR-100 dataset
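For intuition about what autol_init and autol_lr control, here is a toy NumPy sketch of a task-weight update, assuming a one-step lookahead formulation: the weights are moved so that a single weighted training step of the shared parameters lowers the primary tasks' validation loss. It uses quadratic toy losses with analytic gradients, and the helper names, the inner step size `alpha`, and plain gradient descent on the weights are illustrative assumptions. The repository itself operates on PyTorch networks.

```python
import numpy as np


def task_grad(theta, target):
    # Gradient of the toy task loss 0.5 * ||theta - target||^2 w.r.t. theta.
    return theta - target


def autol_step(theta, lam, train_targets, val_targets, primary, alpha=0.1, autol_lr=1e-4):
    """One toy meta-update of the task weights `lam` (hypothetical helper).

    theta: shared parameters; lam: one weight per task (initialised to autol_init);
    train_targets / val_targets: per-task optima on training / validation data;
    primary: indices of the primary tasks whose validation loss drives the update.
    """
    grads = [task_grad(theta, t) for t in train_targets]
    # One-step lookahead: theta' = theta - alpha * sum_i lam_i * grad_i.
    theta_next = theta - alpha * sum((l * g for l, g in zip(lam, grads)), np.zeros_like(theta))
    # Gradient of the summed primary-task validation loss at theta'.
    val_grad = sum((task_grad(theta_next, val_targets[p]) for p in primary), np.zeros_like(theta))
    # Chain rule: d theta' / d lam_i = -alpha * grad_i (grads do not depend on lam here).
    meta_grad = np.array([val_grad @ (-alpha * g) for g in grads])
    # Plain gradient descent on the weights with step size autol_lr.
    return lam - autol_lr * meta_grad


theta = np.zeros(2)
lam = np.full(2, 0.1)  # autol_init = 0.1
targets = [np.array([1.0, 0.0]), np.array([-1.0, 0.0])]  # two conflicting toy tasks
new_lam = autol_step(theta, lam, targets, targets, primary=[0])
print(new_lam[0] > 0.1, new_lam[1] < 0.1)  # → True True
```

The weight of the primary task grows while the conflicting task's weight shrinks; passing `primary=[1]` reverses this, which is why only autol behaves differently when the primary task changes.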

Training in Single-task Learning Mode:

python trainer_dense_single.py --dataset [nyuv2, cityscapes] --task [PRIMARY_TASK] --gpu 0   # for NYUv2 or CityScapes dataset
python trainer_cifar_single.py --subset_id [PRIMARY_DOMAIN_ID] --gpu 0   # for CIFAR-100 dataset

Note: All experiments in the original paper were trained from scratch without pre-training.

Benchmark

For the standard three NYUv2 tasks (without the noise prediction task) in the multi-task learning setting with the Split architecture, please refer to the results below.

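| Method | Type | Sem. Seg. (mIoU) | Depth (aErr.) | Normal (mDist.) | Delta MTL |
| --- | --- | --- | --- | --- | --- |
| Single | - | 43.37 | 52.24 | 22.40 | - |
| Equal | W | 44.64 | 43.32 | 24.48 | +3.57% |
| DWA | W | 45.14 | 43.06 | 24.17 | +4.58% |
| GradDrop | G | 45.39 | 43.23 | 24.18 | +4.65% |
| PCGrad | G | 45.15 | 42.38 | 24.13 | +5.09% |
| Uncertainty | W | 45.98 | 41.26 | 24.09 | +6.50% |
| CAGrad | G | 46.14 | 41.91 | 23.52 | +7.05% |
| Auto-Lambda | W | 47.17 | 40.97 | 23.68 | +8.21% |
| Auto-Lambda + CAGrad | W + G | 48.26 | 39.82 | 22.81 | +11.07% |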

Note: The results were averaged across three random seeds; you should expect an error range of less than ±1%. In the Type column, W denotes a weighting-based method and G a gradient-based method.
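Delta MTL is not defined in this section. A common definition, which reproduces every value in the Delta MTL column of the benchmark table to two decimal places, is the average per-task relative change over the single-task baseline, with the sign flipped for metrics where lower is better (here depth aErr. and normal mDist.). A short sketch of that calculation, with a hypothetical helper name:

```python
def delta_mtl(method, single, lower_is_better):
    """Average relative improvement (%) of a method over single-task baselines.

    method, single: per-task metric values in the same task order.
    lower_is_better: per-task flags, True for error metrics.
    """
    total = 0.0
    for m, b, lower in zip(method, single, lower_is_better):
        sign = -1.0 if lower else 1.0  # an error decrease counts as an improvement
        total += sign * (m - b) / b
    return 100.0 * total / len(method)


single = [43.37, 52.24, 22.40]  # Sem. Seg. (mIoU), Depth (aErr.), Normal (mDist.)
lower = [False, True, True]
print(round(delta_mtl([44.64, 43.32, 24.48], single, lower), 2))  # Equal → 3.57
```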

Citation

If you found this code/work to be useful in your own research, please consider citing the following:

@article{liu2022auto_lambda,
    title={Auto-Lambda: Disentangling Dynamic Task Relationships},
    author={Liu, Shikun and James, Stephen and Davison, Andrew J and Johns, Edward},
    journal={Transactions on Machine Learning Research},
    year={2022}
}

Acknowledgement

We would like to thank @Cranial-XIX for his clean implementation of the gradient-based optimisation methods.

Contact

If you have any questions, please contact sk.lorenmt@gmail.com.