
Cross-task Attention Mechanism for Dense Multi-task Learning (DenseMTL)

This repository provides the official source code and model weights for our Cross-task Attention Mechanism for Dense Multi-task Learning paper (WACV 2023). The implementation uses PyTorch.

<p align="center"> <img src="./docs/dark.png#gh-dark-mode-only" width="700"/> <img src="./docs/light.png#gh-light-mode-only" width="700"/> </p>

DenseMTL: Cross-task Attention Mechanism for Dense Multi-task Learning
Ivan Lopes<sup>1</sup>, Tuan-Hung Vu<sup>1,2</sup>, Raoul de Charette<sup>1</sup><br> <sup>1</sup> Inria, Paris, France. <sup>2</sup> Valeo.ai, Paris, France.<br>

<img src="./docs/berlin.gif" />

To cite our paper, please use:

@inproceedings{lopes2023densemtl,
  title={Cross-task Attention Mechanism for Dense Multi-task Learning},
  author={Lopes, Ivan and Vu, Tuan-Hung and de Charette, Raoul},
  booktitle={WACV},
  year={2023}
}

Table of contents

Overview

DenseMTL is a cross-attention-based multi-task architecture that leverages multiple attention mechanisms to extract and enrich task features. As seen in the figure above, each xTAM module receives a pair of distinct task features to better assess cross-task interactions and enable efficient cross-talk distillation.

Overall, this work covers a wide range of experiments across multiple datasets and task sets.

Installation

1. Dependencies

First create a new conda environment with the required packages found in environment.yml. You can do so with the following line:

>>> conda env create -n densemtl -f environment.yml

Then activate environment densemtl using:

>>> conda activate densemtl

2. Datasets

3. Environment variables

Update configs/env_config.yml with the paths to the different directories.

All constants provided in this file are loaded as environment variables and are accessible at runtime via os.environ. Alternatively, those constants can be defined on the command line before running the project.
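As a rough sketch, assuming the Cityscapes, Virtual KITTI 2, and Synthia datasets are used, the file could look like the following. Only CITYSCAPES_DIR and LOG_DIR are referenced elsewhere in this README; the other key names are placeholders to adjust to your setup:

CITYSCAPES_DIR: /path/to/cityscapes    # Cityscapes root directory
VKITTI2_DIR: /path/to/vkitti2          # placeholder name: Virtual KITTI 2 root
SYNTHIA_DIR: /path/to/synthia          # placeholder name: Synthia root
LOG_DIR: /path/to/logs                 # output directory for checkpoints, models and configs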

Running DenseMTL

1. Command Line Interface

The following are the command-line interface arguments and options (an example invocation is shown after the list):

--env-config ENV_CONFIG
  Path to file containing the environment paths, defaults to configs/env_config.yml.
--base CONFIG
  Optional path to a base configuration yaml file; can be left unused if the --config file contains all keys.
--config CONFIG
  Path to main configuration yaml file.
--project PROJECT
  Project name for logging, also used as the wandb project name.
--resume
  Flag to resume training; this will look for the last available model checkpoint from the same setup.
--evaluate PATH
  Will load the model provided at the file path and perform evaluation using the config setup.
-s SEED, --seed SEED
  Seed for training and dataset.
-d, --debug
  Flag to perform single validation inference for debugging purposes.
-w, --disable-wandb
  Flag to disable Weights & Biases logging.
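
For instance, a single-pass debug run without wandb logging could be launched along these lines (the configuration paths reuse the ones from the evaluation example further down):

python main.py \
  --config=configs/vkitti2/resnet101_ours_SD.yml \
  --base=configs/vkitti2/fs_bs2.yml \
  --debug \
  --disable-wandb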

Experiments are driven by configuration files. Each configuration file must follow this structure:

setup:
  name: exp-name
  model:
    └── model args
  loss:
    └── loss args
  lr:
    └── learning rates
data:
  └── data module args
training:
  └── training args
optim:
  └── optimizer args
scheduler:
  └── scheduler args

For arguments that recur across experiments, such as data, training, optim, and scheduler, we use a base configuration file passed to the process via the --base option. The two configuration files (provided with --base and --config) are merged at the top level (config can overwrite base). See main.py for more details.

Environment variables can be referenced inside the configuration file using the $ENV: prefix, e.g. path: $ENV:CITYSCAPES_DIR.
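
For example, the data section of a configuration file could point to the Cityscapes directory like this (the surrounding layout is only illustrative; the documented part is the $ENV: prefix):

data:
  path: $ENV:CITYSCAPES_DIR    # expanded from the CITYSCAPES_DIR environment variable at runtime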

2. Experiments

To reproduce the experiments, run main.py with the corresponding configuration files from the configs/ directory.
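
For instance, a fully supervised SD run on Virtual KITTI 2 with seed 42 (matching the logging example in section 5) could be launched along these lines; the project name is arbitrary:

python main.py \
  --config=configs/vkitti2/resnet101_ours_SD.yml \
  --base=configs/vkitti2/fs_bs2.yml \
  --project=densemtl \
  --seed=42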

3. Models

Our models trained in the fully supervised setting:

| Setup | Set | Link |
|---|---|---|
| Synthia | SD | sy_densemtl_SD.pkl |
| Synthia | SDN | sy_densemtl_SDN.pkl |
| Virtual Kitti 2 | SD | vk_densemtl_SD.pkl |
| Virtual Kitti 2 | SDN | vk_densemtl_SDN.pkl |
| Cityscapes | SD | cs_densemtl_SD.pkl |
| Cityscapes | SDN | cs_densemtl_SDN.pkl |

4. Evaluation

To evaluate a model, set the --evaluate option to the path of a state dictionary .pkl file. The weight file will be loaded onto the model and the evaluation loop launched. Keep in mind that you also need to provide valid configuration files. For example, to evaluate our method with the weights located in weights/vkitti2_densemtl_SD.pkl, simply run:

python main.py \
  --config=configs/vkitti2/resnet101_ours_SD.yml \
  --base=configs/vkitti2/fs_bs2.yml \
  --evaluate=weights/vkitti2_densemtl_SD.pkl

5. Visualization & Logging

By default, visualizations, losses, and metrics are logged using Weights & Biases. If you do not wish to log your training and evaluation runs through this tool, you can disable it with the --disable-wandb flag. In all cases, loss values and metrics are logged to standard output.

Checkpoints, models, and configuration files are saved under the LOG_DIR directory. More specifically, they are located under <LOG_DIR>/<dataset>/<config-name>/s<seed>/<timestamp>. For example, you could have <LOG_DIR>/vkitti2/resnet101_ours_SD/s42/2022-04-19_10-09-49 for an SD training of our method on VKITTI2 with seed 42.
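
To continue such a run, the --resume flag can be combined with the same setup; a sketch, assuming the interrupted run above:

python main.py \
  --config=configs/vkitti2/resnet101_ours_SD.yml \
  --base=configs/vkitti2/fs_bs2.yml \
  --seed=42 \
  --resume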

Project structure

The main.py file is the entry point to perform training and evaluation on the different setups.

root
  ├── configs/    % Configuration files to run the experiments
  ├── training/   % Training loops for all settings
  ├── dataset/    % PyTorch dataset definitions as well as semantic segmentation encoding logic
  ├── models/     % Neural network modules, inference logic and method implementation
  ├── optim/      % Optimizer-related code
  ├── loss/       % Loss modules for each task type, including the metric and visualization calls
  ├── metrics/    % Task metric implementations
  ├── vendor/     % Third party source code
  └── utils/      % Utility code for other parts of the code

Credit

This repository contains code taken from Valeo.ai's ADVENT, Simon Vandenhende's MTL-survey, Niantic Labs' Monodepth 2, and Lukas Hoyer's 3-Ways.

License

DenseMTL is released under the Apache 2.0 license.


↑ back to top