Home

Awesome

DMBP: Diffusion Model-Based Predictor for robust offline reinforcement learning against state observation perturbations

This is the implementation of Decision Model-Based Predictor (DMBP) and the reproduced baseline algorithms (including BCQ, CQL, TD3+BC, RORL, and Diffusion QL).

Note that this code implementation is only used for double-blind review purpose. Click here for official implementation.

GOOD NEWS!

Our paper has been accepted by ICLR2024 for poster presentation! Click here if you would like to refer to the reviews on our paper.

<!--- A de-anonymized official implementation of DMBP will be presented at [here](https://github.com/zhyang2226/DMBP) and this website will be closed when the new one is well-established. We will add the visualization module in the final version. Please star our new web, create issues there if you need. -->

If you find DMBP helpful for your work, please cite:

@inproceedings{
yang2024dmbp,
title={{DMBP}: Diffusion model based predictor for robust offline reinforcement learning against state observation perturbations},
author={Yang, Zhihe and Xu, Yunjian},
booktitle={The Twelfth International Conference on Learning Representations},
year={2024}
}

Introduction

A major challenge for the real-world application of offline RL stems from the robustness against state observation perturbations, e.g., as a result of sensor errors or adversarial attacks. Unlike online robust RL, agents cannot be adversarially trained in the offline setting.

<div style="text-align: center;"> <img src="present/intro.png" width = "600" height= "230" align=center > </div>

Our proposed Diffusion Model-Based Predictor (DMBP) is a new framework that recovers the actual states with conditional diffusion models for state-based RL tasks. Our derived non-Markovian training objective reduces the error accumulations, which are commonly observed for model-based methods in state-based RL tasks.

Follow we present the visualization of the denoising effect of DMBP with Diffusion QL (trained on hopper-medium-replay-v2). The observation is perturbed with Gaussian distributed random noise with std of 0.10.

<img src="present/Hopper_medium_replay.gif" width = "700" height = "700" alt="DMBP_Visualization" align=center />

Requirement

Our experiment is performed on D4RL benchmark environments and datasets (click here). Please install the Mujoco Version 2.1 (click here) before getting start. See requirements.txt for detailed environment set up.

Training

Baseline Algorithms Training

Before training DMBP, train a baseline offline RL algorithm at first:

python -m scripts.train_baseline --algo [ALGORITHM_NAME] --env_name [ENV_NAME] --dataset [DATASET_NAME]

DMBP Training

DMBP utilizes the trajectory datasets for training. Download the datasets of the corresponding domain through

python -m scripts.download_datasets --domain [DOMAIN_NAME]

Then DMBP can be trained through:

python -m scripts.train_DMBP --task [TASK_NAME] --algo [ALGORITHM_NAME] --env_name [ENV_NAME] --dataset [DATASET_NAME]

where the previously trained baseline algorithms are only used for training-process evaluation.

Evaluation

Robustness against noised state observations

To evaluate the baseline algorithm against different attacks on state observations, run:

python -m evaluations.eval_baseline_noise --noise_type [ATTACK_METHOD] --algo [ALGORITHM_NAME] --env_name [ENV_NAME] --dataset [DATASET_NAME]

Then, the evaluation on the corresponding baseline algorithm strengthed by DMBP can be conducted through:

python -m evaluations.eval_DMBP_noise --noise_type [ATTACK_METHOD] --algo [ALGORITHM_NAME] --env_name [ENV_NAME] --dataset [DATASET_NAME]

Robustness against incomplete state observations with unobserved dimensions

Similarly, to evaluate the baseline algorithm against incomplete state observations, run:

python -m evaluations.eval_baseline_mask --mask_dim [MASKED_DIM] --algo [ALGORITHM_NAME] --env_name [ENV_NAME] --dataset [DATASET_NAME]

Then, the evaluation on DMBP strenghted baseline algorithm can be conducted through:

python -m evaluations.eval_DMBP_mask --mask_dim [MASKED_DIM] --algo [ALGORITHM_NAME] --env_name [ENV_NAME] --dataset [DATASET_NAME]