Exploring Q-Learning in a Pure-GPU Setting
<img src="https://img.shields.io/badge/license-Apache2.0-blue.svg">
The goal of this project is to provide very simple and lightweight scripts for Q-Learning baselines in various single-agent and multi-agent settings that can run effectively with pure-GPU environments. It follows the cleanrl philosophy of single-file scripts and is deeply inspired by purejaxrl, which aims to compile entire RL pipelines on the GPU using JAX.
The main algorithm currently supported is the Parallelised Q-Network (PQN), developed specifically to run effectively in a pure-GPU setting. Its main features are listed below, followed by a minimal sketch of the core update:
- Simplicity: PQN is a very simple baseline, essentially an online Q-learner with vectorized environments and network normalization.
- Speed: PQN runs without a replay buffer or target network, which ensures significant speed-ups while maintaining sample efficiency.
- Stability: PQN can use both batch and layer normalization to stabilize training.
- Flexibility: PQN is easily compatible with RNNs, Peng's $Q(\lambda)$, and multi-agent tasks.
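To make these points concrete, here is a minimal sketch in JAX/Flax of the core update described above: a LayerNorm-regularized Q-network trained online on batches from vectorized environments, with the TD target computed from the same network, so no replay buffer or target network is needed. The names (`QNetwork`, `td_loss`) and sizes are illustrative, not the repository's actual code, which also supports RNNs and Peng's $Q(\lambda)$ returns.

```python
# Illustrative sketch only (not the repository's implementation): an online Q-learner
# with a LayerNorm-regularized MLP, no replay buffer and no target network.
import jax
import jax.numpy as jnp
import flax.linen as nn


class QNetwork(nn.Module):  # hypothetical module name
    num_actions: int

    @nn.compact
    def __call__(self, x):
        for _ in range(2):
            x = nn.Dense(128)(x)
            x = nn.LayerNorm()(x)  # normalization stabilizes learning without a target network
            x = nn.relu(x)
        return nn.Dense(self.num_actions)(x)


def td_loss(params, apply_fn, obs, actions, rewards, dones, next_obs, gamma=0.99):
    """1-step TD loss on a batch collected from vectorized environments."""
    q = apply_fn(params, obs)  # (batch, num_actions)
    q_taken = jnp.take_along_axis(q, actions[:, None], axis=-1).squeeze(-1)
    next_q = apply_fn(params, next_obs).max(axis=-1)  # greedy bootstrap from the online network
    target = jax.lax.stop_gradient(rewards + gamma * (1.0 - dones) * next_q)
    return jnp.mean((q_taken - target) ** 2)
```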
🔥 Quick Stats
With PQN and a single NVIDIA A40 (achieving similar performance to an RTX 3090), you can:
- 🦿 Train agents for simple tasks like CartPole and Acrobot in a few seconds.
- Train thousands of seeds in parallel in a few minutes (see the sketch after this list).
- Train a MinAtar agent in less than a minute, and complete 10 parallel seeds in less than 5 minutes.
- 🕹️ Train an Atari agent for 200M frames in one hour (with environments running on a single CPU using Envpool, tested on an AMD EPYC 7513 32-Core Processor).
- Solve simple games like Pong in a few minutes and less than 10M timesteps.
- 👾 Train a Q-Learning agent in Craftax much faster than when using a replay buffer.
- 👥 Train a strong Q-Learning baseline with VDN in multi-agent tasks.
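The parallel-seeds numbers above come from compiling the entire training loop with JAX and vectorizing it over random seeds, as popularized by purejaxrl. Below is a rough, hedged sketch of that pattern; the `make_train` factory and its dummy body are stand-ins, not the scripts' exact API.

```python
# Rough sketch of training many seeds in parallel on one GPU (illustrative only).
import jax
import jax.numpy as jnp


def make_train(config):
    """Hypothetical factory returning a pure train(rng) function; in the real scripts
    this builds the environment, network and optimizer and runs the full loop."""
    def train(rng):
        noise = jax.random.normal(rng, (config["num_updates"],))  # dummy stand-in computation
        return {"returns": jnp.cumsum(noise)}
    return train


config = {"num_updates": 100}
rngs = jax.random.split(jax.random.PRNGKey(0), 128)  # 128 independent seeds
train_vjit = jax.jit(jax.vmap(make_train(config)))   # vectorize over seeds, compile once
outs = train_vjit(rngs)                              # all seeds train in a single device call
print(outs["returns"].shape)                         # (128, 100)
```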
🦾 Performance
Atari
Currently, after around 4 hours of training and 400M environment frames, PQN achieves a median score in ALE similar to the original Rainbow paper, scoring above human level in 40 of the 57 Atari games. While this is far from the latest SOTA in ALE, it can serve as a good starting point for faster ALE research.
<table style="width: 100%; text-align: center; border-collapse: collapse;"> <tr> <td style="width: 33.33%; vertical-align: top; padding: 10px;"> <h4>Median Score</h4> <img src="docs/atari-57_median.png" alt="Atari-57_median" width="300" style="max-width: 100%; display: block; margin: 0 auto;"/> </td> <td style="width: 33.33%; vertical-align: top; padding: 10px;"> <h4>Performance Profile</h4> <img src="docs/atari-57_tau.png" alt="Atari-57_tau" width="300" style="max-width: 100%; display: block; margin: 0 auto;"/> </td> <td style="width: 33.33%; vertical-align: top; padding: 10px;"> <h4>Training Speed</h4> <img src="docs/atari-57_speed.png" alt="Atari-57_speed" width="300" style="max-width: 100%; display: block; margin: 0 auto;"/> </td> </tr> </table>

Craftax
When combined with an RNN network, PQN offers a more sample-efficient baseline compared to PPO. As an off-policy algorithm, PQN could be an interesting starting point for population-based training in Craftax!
<div style="text-align: center; margin: auto;"> <img src="docs/craftax_rnn.png" alt="craftax_rnn" width="300" style="max-width: 100%;"/> </div>

Multi-Agent (JaxMarl)
When combined with Value Decomposition Networks, PQN is a strong baseline for multi-agent tasks.
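VDN factorizes the joint action value as a sum of per-agent utilities, $Q_{tot}(s, \mathbf{a}) = \sum_i Q_i(o_i, a_i)$, so the single-agent TD loss can be applied to the summed value unchanged. A minimal, hedged sketch of the mixing step (function name and shapes are illustrative, not the repository's exact code):

```python
# Illustrative sketch of VDN mixing (not the repository's exact code).
import jax.numpy as jnp


def vdn_qtot(per_agent_q, per_agent_actions):
    """per_agent_q: (n_agents, batch, n_actions); per_agent_actions: (n_agents, batch)."""
    chosen = jnp.take_along_axis(per_agent_q, per_agent_actions[..., None], axis=-1).squeeze(-1)
    return chosen.sum(axis=0)  # Q_tot = sum_i Q_i(o_i, a_i), shape (batch,)


q = jnp.zeros((3, 32, 5))                # 3 agents, batch of 32, 5 actions each
a = jnp.zeros((3, 32), dtype=jnp.int32)  # chosen actions
print(vdn_qtot(q, a).shape)              # (32,)
```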
<table style="width: 100%; margin: auto; text-align: center; border-collapse: collapse;"> <tr> <td style="width: 50%; vertical-align: top; padding: 10px;"> <h4>Smax</h4> <img src="docs/smax_iqm.png" alt="smax" width="300" style="max-width: 100%; display: block; margin: 0 auto;"/> </td> <td style="width: 50%; vertical-align: top; padding: 10px;"> <h4>Overcooked</h4> <img src="docs/overcooked_iqm.png" alt="overcooked" width="300" style="max-width: 100%; display: block; margin: 0 auto;"/> </td> </tr> </table>

🚀 Usage (highly recommended with Docker)
Steps:
- Ensure you have Docker and the NVIDIA Container Toolkit properly installed.
- (Optional) Set your WANDB key in the Dockerfile.
- Build the image with `bash docker/build.sh`.
- (Optional) Build the Atari-specific image (which uses different gym requirements) with `bash docker/build_atari.sh`.
- Run a container with `bash docker/run.sh` (for Atari: `bash docker/run_atari.sh`).
- Test a training script: `python purejaxql/pqn_minatar.py +alg=pqn_minatar`.
Useful commands:
```bash
# cartpole
python purejaxql/pqn_gymnax.py +alg=pqn_cartpole

# train in atari with a specific game
python purejaxql/pqn_atari.py +alg=pqn_atari alg.ENV_NAME=NameThisGame-v5

# pqn rnn with craftax
python purejaxql/pqn_rnn_craftax.py +alg=pqn_rnn_craftax

# pqn-vdn in smax
python purejaxql/pqn_vdn_rnn_jaxmarl.py +alg=pqn_vdn_rnn_smax

# Perform hyper-parameter tuning
python purejaxql/pqn_gymnax.py +alg=pqn_cartpole HYP_TUNE=True
```
Experiment Configuration
Check `purejaxql/config/config.yaml` for the default configuration. It allows you to set up WANDB, the seed, and the number of parallel seeds per experiment.

The algorithm-environment specific configuration files are in `purejaxql/config/alg`.
Most scripts include a `tune` function to perform hyperparameter tuning. You'll need to set `HYP_TUNE=True` in the default config file to use it.
Citation
If you use PQN in your work, please cite the following paper:
```bibtex
@misc{gallici2024simplifyingdeeptemporaldifference,
    title={Simplifying Deep Temporal Difference Learning},
    author={Matteo Gallici and Mattie Fellows and Benjamin Ellis and Bartomeu Pou and Ivan Masmitja and Jakob Nicolaus Foerster and Mario Martin},
    year={2024},
    eprint={2407.04811},
    archivePrefix={arXiv},
    primaryClass={cs.LG},
    url={https://arxiv.org/abs/2407.04811},
}
```
Related Projects
The following repositories are related to pure-GPU RL training: