Decision Transformer

Overview

Minimal code for Decision Transformer: Reinforcement Learning via Sequence Modeling, applied to MuJoCo control tasks in OpenAI Gym. Notable differences from the official implementation are:

Open min_decision_transformer.ipynb in Google Colab.
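Decision Transformer conditions each predicted action on the return-to-go, the sum of rewards from the current timestep to the end of the episode. The sketch below shows one way to compute returns-to-go for a single trajectory. The helper name `returns_to_go` is hypothetical and not taken from this repo; `gamma=1.0` (undiscounted) matches the return-to-go definition in the Decision Transformer paper.

```python
import numpy as np


def returns_to_go(rewards: np.ndarray, gamma: float = 1.0) -> np.ndarray:
    """Hypothetical helper: return-to-go at each step (current reward included)."""
    rtg = np.zeros(len(rewards), dtype=np.float64)
    running = 0.0
    # Walk backwards so each step reuses the sum already computed for the next step.
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        rtg[t] = running
    return rtg


print(returns_to_go(np.array([1.0, 0.0, 2.0, 3.0])).tolist())  # [6.0, 5.0, 5.0, 3.0]
```

At evaluation time the paper's approach starts from a chosen target return and subtracts each observed reward from it after every environment step.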

Results

Note: these results are the mean and variance over 3 random seeds, obtained after 20k updates (due to GPU time limits on Colab), whereas the official results are obtained after 100k updates. The numbers are therefore not directly comparable, but together with their corresponding plots they serve as rough reference points for measuring the model's learning progress. The variance in returns and scores should decrease as training reaches saturation.

Dataset   Environment   DT (this repo), 20k updates   DT (official), 100k updates
Medium    HalfCheetah   42.18 ± 0.59                  42.60 ± 0.10
Medium    Hopper        69.43 ± 27.34                 67.60 ± 1.00
Medium    Walker        75.47 ± 31.08                 74.00 ± 1.40
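A minimal sketch of how per-seed evaluation returns could be turned into "mean ± variance" entries like those in the table, assuming D4RL-style score normalization (0 for a random policy, 100 for an expert policy). The helper names are hypothetical, and the random and expert reference returns are inputs you would take from D4RL; they are not reproduced here. Since the note describes the ± values as variance, this sketch uses the population variance (`statistics.pvariance`); `statistics.pstdev` would give a standard deviation instead.

```python
import statistics


def normalized_score(episode_return: float, random_return: float, expert_return: float) -> float:
    """Hypothetical helper: D4RL-style normalization, 0 ~ random policy, 100 ~ expert policy."""
    return 100.0 * (episode_return - random_return) / (expert_return - random_return)


def summarize(per_seed_returns, random_return: float, expert_return: float):
    """Mean and population variance of normalized scores across random seeds."""
    scores = [normalized_score(r, random_return, expert_return) for r in per_seed_returns]
    return statistics.mean(scores), statistics.pvariance(scores)


# Toy inputs for illustration only; these are not results from this repo.
mean, var = summarize([40.0, 50.0, 60.0], random_return=0.0, expert_return=100.0)
print(f"{mean:.2f} ± {var:.2f}")  # 50.00 ± 66.67
```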

Instructions

Mujoco-py

Install the mujoco-py library by following the instructions in the mujoco-py repo.

D4RL Data

Datasets are expected to be stored in the data directory. Install the D4RL repo, then save the formatted data in the data directory by running the following script:

python3 data/download_d4rl_datasets.py
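The exact output format of download_d4rl_datasets.py is not described here. As a hedged sketch, assuming the raw dataset is a dict of flat per-transition arrays with D4RL-style keys (`observations`, `actions`, `rewards`, `terminals`, `timeouts`), splitting it into per-episode trajectories could look like this. `split_into_trajectories` is a hypothetical name, not the script's actual function.

```python
import numpy as np


def split_into_trajectories(dataset: dict) -> list:
    """Hypothetical helper: split flat transition arrays into per-episode dicts.

    An episode ends at any step where `terminals` or `timeouts` is True.
    """
    n = len(dataset["rewards"])
    trajectories, start = [], 0
    for i in range(n):
        done = bool(dataset["terminals"][i]) or bool(dataset["timeouts"][i])
        # Also close the last (possibly incomplete) episode at the end of the data.
        if done or i == n - 1:
            end = i + 1
            trajectories.append({
                key: np.asarray(dataset[key][start:end])
                for key in ("observations", "actions", "rewards", "terminals")
            })
            start = end
    return trajectories


toy = {
    "observations": np.arange(10, dtype=np.float32).reshape(5, 2),
    "actions": np.zeros((5, 1), dtype=np.float32),
    "rewards": np.ones(5, dtype=np.float32),
    "terminals": np.array([False, True, False, False, False]),
    "timeouts": np.array([False, False, False, True, False]),
}
print([len(t["rewards"]) for t in split_into_trajectories(toy)])  # [2, 2, 1]
```

This sketch keeps the final partial episode; some pipelines drop it instead.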

Running experiments

python3 scripts/train.py --env halfcheetah --dataset medium --device cuda
python3 scripts/test.py --env halfcheetah --dataset medium --device cpu --num_eval_ep 1 --chk_pt_name dt_halfcheetah-medium-v2_model_22-02-13-09-03-10_best.pt

The dataset must be specified for testing so that the same state normalization statistics (mean and var) used during training are loaded. An additional --render flag can be passed to the script to render the test episode.
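A minimal sketch of why the dataset matters at test time: normalization statistics are computed once from the training trajectories and then reused for every state the model sees during evaluation. It assumes each trajectory is a dict with an `observations` array of shape (T, state_dim); the helper names and the epsilon are hypothetical, and this sketch divides by the standard deviation (the square root of the variance), which may differ in detail from this repo.

```python
import numpy as np


def compute_state_stats(trajectories, eps: float = 1e-6):
    """Per-dimension mean and std over every state in the training trajectories."""
    states = np.concatenate([traj["observations"] for traj in trajectories], axis=0)
    mean = states.mean(axis=0)
    std = states.std(axis=0) + eps  # eps avoids division by zero for constant dimensions
    return mean, std


def normalize_state(state, mean, std):
    """Apply the training-time statistics to any state, training or test."""
    return (state - mean) / std


train_trajs = [
    {"observations": np.array([[0.0, 0.0], [2.0, 2.0]])},
    {"observations": np.array([[4.0, 4.0]])},
]
state_mean, state_std = compute_state_stats(train_trajs)
print(state_mean.tolist())  # [2.0, 2.0]
print(normalize_state(np.array([2.0, 2.0]), state_mean, state_std).tolist())  # [0.0, 0.0]
```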

python3 scripts/plot.py --env_d4rl_name halfcheetah-medium-v2 --smoothing_window 5

Additionally, the --plot_avg and --save_fig flags can be passed to the script to average all values into one plot and to save the figure, respectively.
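How plot.py smooths and averages curves is not specified here. A common choice, assumed in this sketch, is a trailing moving average over `--smoothing_window` points, with `--plot_avg` taken to mean an element-wise average of equal-length curves across runs. Both helpers are hypothetical.

```python
import numpy as np


def smooth(values, window: int) -> np.ndarray:
    """Trailing moving average; early points average over the samples available so far."""
    values = np.asarray(values, dtype=np.float64)
    # Prefix sums turn each window sum into a single subtraction.
    prefix = np.concatenate(([0.0], np.cumsum(values)))
    out = np.empty_like(values)
    for i in range(len(values)):
        lo = max(0, i - window + 1)
        out[i] = (prefix[i + 1] - prefix[lo]) / (i + 1 - lo)
    return out


def average_runs(runs) -> np.ndarray:
    """Element-wise mean across runs (e.g. seeds); assumes equal-length curves."""
    return np.mean(np.stack([np.asarray(r, dtype=np.float64) for r in runs]), axis=0)


print(smooth([1, 2, 3, 4, 5], window=2).tolist())  # [1.0, 1.5, 2.5, 3.5, 4.5]
print(average_runs([[1, 2], [3, 4]]).tolist())  # [2.0, 3.0]
```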

Note:

  1. If you find it difficult to install mujoco-py and d4rl, refer to how they are installed in the Colab notebook.
  2. Once the dataset is formatted and saved with download_d4rl_datasets.py, the d4rl library is no longer required for training.
  3. Evaluation is done on the v3 control environments in mujoco-py so that the results are consistent with the Decision Transformer paper.
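Note 2 means training only needs the saved files. Assuming the formatted data is a pickled list of trajectory dicts (the file name and format below are hypothetical, not confirmed for this repo), loading it requires only pickle and NumPy, with no d4rl import:

```python
import os
import pickle
import tempfile

import numpy as np


def save_trajectories(trajectories, path):
    with open(path, "wb") as f:
        pickle.dump(trajectories, f)


def load_trajectories(path):
    # Only pickle (and NumPy for the arrays inside) is needed; d4rl is never imported.
    with open(path, "rb") as f:
        return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "halfcheetah-medium-v2.pkl")  # hypothetical file name
    trajs = [{"observations": np.zeros((3, 17)), "rewards": np.ones(3)}]
    save_trajectories(trajs, path)
    loaded = load_trajectories(path)

print(len(loaded), loaded[0]["observations"].shape)  # 1 (3, 17)
```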

Citing

Please use this bibtex if you want to cite this repository in your publications:

@misc{minimal_decision_transformer,
    author = {Barhate, Nikhil},
    title = {Minimal Implementation of Decision Transformer},
    year = {2022},
    publisher = {GitHub},
    journal = {GitHub repository},
    howpublished = {\url{https://github.com/nikhilbarhate99/min-decision-transformer}},
}

References