Decision Transformer in JAX and Haiku

A reproduction of "Decision Transformer: Reinforcement Learning via Sequence Modeling" (https://arxiv.org/abs/2106.01345) in JAX and Haiku.

<img width="1000" height="auto" alt="decision-transformer-jax" src="https://user-images.githubusercontent.com/85018688/160403605-fc54ce19-c794-452c-88a0-d27f7943297c.png">
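The paper's core idea is to cast RL as sequence modeling: each trajectory is fed to a causal transformer as interleaved (return-to-go, state, action) tokens, and the model is trained to predict actions conditioned on a target return. A minimal jax.numpy sketch of that input construction (illustrative only; the names and shapes here are assumptions, not this repo's actual code):

```python
import jax.numpy as jnp

def returns_to_go(rewards):
    # Undiscounted return-to-go R_t = sum over t' >= t of r_{t'}, as in the paper.
    # Reverse, take a cumulative sum, then reverse back.
    return jnp.flip(jnp.cumsum(jnp.flip(rewards)))

print(returns_to_go(jnp.array([1.0, 0.0, 2.0, 1.0])))  # [4. 3. 3. 1.]

# Interleave per-timestep embeddings into a single sequence of length 3*T,
# ordered (R_1, s_1, a_1, R_2, s_2, a_2, ...). Dummy embeddings stand in for
# the learned return/state/action projections.
T, d = 4, 8
R_emb = jnp.zeros((T, d))
s_emb = jnp.ones((T, d))
a_emb = jnp.full((T, d), 2.0)
tokens = jnp.stack([R_emb, s_emb, a_emb], axis=1).reshape(3 * T, d)
print(tokens.shape)  # (12, 8)
```

The interleaved `tokens` array is what a causal transformer would attend over, with action predictions read off at the state-token positions.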

Results

Atari

<img src="imgs/atari_result_1.png" width="800" height="auto"> <img src="https://user-images.githubusercontent.com/85018688/160198790-df5b7724-7436-41e5-a1b7-57123ec538e7.png" width="800" height="auto">

~2x faster training than the original implementation, with comparable evaluation performance.
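Much of such a speedup in JAX typically comes from XLA-compiling the entire training step with jax.jit, so repeated calls reuse one fused, compiled program. A hypothetical sketch of that pattern, using a toy linear model and plain SGD (not this repo's actual model, loss, or optimizer):

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Toy squared-error loss; the actual repo trains a transformer
    # (e.g. cross-entropy over discrete actions on Atari).
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit  # compile the whole update once; later calls hit the cached XLA program
def update(params, x, y, lr=1e-2):
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = {"w": jnp.zeros((3,)), "b": jnp.zeros(())}
x = jnp.ones((5, 3))
y = jnp.ones((5,))
params = update(params, x, y)  # one SGD step; loss decreases from 1.0
```

Because the gradient computation and parameter update are fused into a single compiled call, Python overhead is paid only at trace time, which is the usual source of JAX's wall-clock advantage.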

Usage

Setup

Alternatively, you can set up the project using the auto-generated requirements-cpu.txt or requirements-gpu.txt, e.g. pip install -r requirements-gpu.txt (tested with python=3.8, cudatoolkit=11.1, cudnn=8.2).

Training

Run cd dt_jax && bash run.sh to train the model.

Evaluating a pre-trained model

Testing

Run python -m unittest discover -s dt_jax -p '*_test.py' to run the unit tests.

Author

Yun-hyeok Kwak (yunhyeok.kwak@gmail.com)

Credits

License

MIT License

Copyright (c) 2022 Yun-hyeok Kwak