transformerXL_PPO_JAX

This repository provides a JAX implementation of TransformerXL with PPO in an RL setup, following "Stabilizing Transformers for Reinforcement Learning" by Parisotto et al. (https://arxiv.org/abs/1910.06764).

The code uses the PureJaxRL template for PPO and ports parts of the Hugging Face TransformerXL implementation to JAX. We also took inspiration from the PyTorch code in https://github.com/MarcoMeter/episodic-transformer-memory-ppo, which simplifies gradient propagation and positional encoding compared to TransformerXL as described in the original paper (https://arxiv.org/abs/1901.02860).
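To illustrate the core TransformerXL idea referenced above, here is a minimal numpy sketch of segment-level memory: hidden states cached from the previous segment are concatenated to the current segment's keys and values, so attention sees an extended context. All names and shapes here are illustrative, not taken from this repository's code.

```python
import numpy as np

def attention_with_memory(x, memory, Wq, Wk, Wv):
    """Single-head attention where keys/values also cover cached memory."""
    context = np.concatenate([memory, x], axis=0)   # (mem_len + seg_len, d)
    q = x @ Wq                                      # queries: current segment only
    k = context @ Wk
    v = context @ Wv
    scores = q @ k.T / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over extended context
    return weights @ v                              # (seg_len, d)

rng = np.random.default_rng(0)
d, seg_len, mem_len = 8, 4, 6
Wq, Wk, Wv = (rng.standard_normal((d, d)) for _ in range(3))
memory = rng.standard_normal((mem_len, d))   # hidden states cached from the previous segment
segment = rng.standard_normal((seg_len, d))
out = attention_with_memory(segment, memory, Wq, Wk, Wv)
print(out.shape)  # (4, 8)
```

In the actual architecture the memory is detached from the gradient computation (a stop-gradient in JAX terms), which is part of the simplification discussed in the episodic-transformer-memory-ppo repository.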

The training supports any Gymnax environment.

We also tested it on Craftax, where it beats the baselines presented in the paper (https://arxiv.org/abs/2402.16801), including PPO-RNN, training with unsupervised environment design, and intrinsic motivation. Notably, we reach the third level (the sewer) and obtain several advanced achievements, which none of the methods presented in the paper achieved. See Results on Craftax for more information.

Training a 5M-parameter transformer on Craftax for 1e9 steps (with 1024 environments) takes about 6.5 hours on a single A100.

Installation

git clone git@github.com:Reytuag/transformerXL_PPO_JAX.git
cd transformerXL_PPO_JAX
pip install -r requirements.txt

:warning: By default, this will install the CPU version of JAX. To install the GPU version, follow https://jax.readthedocs.io/en/latest/installation.html.

Training

You can edit the training config in train_PPO_trXL.py (or train_PPO_trXL_pmap.py if you want to go multi-GPU), including the name of the environment. You can use any Gymnax environment name, or "craftax", which will use the CraftaxSymbolic environment.
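As an illustration only, such configs are typically plain dicts of hyperparameters. The keys below are hypothetical placeholders (except WINDOW_GRAD and NUM_STEPS, which are mentioned under Config parameters); check train_PPO_trXL.py for the actual names and values.

```python
# Hypothetical config excerpt -- key names are illustrative,
# see train_PPO_trXL.py for the real ones.
config = {
    "ENV_NAME": "craftax",   # any Gymnax env name, or "craftax"
    "NUM_ENVS": 1024,        # parallel environments
    "NUM_STEPS": 128,        # rollout length per update
    "WINDOW_GRAD": 64,       # must divide NUM_STEPS (see Config parameters)
    "TOTAL_TIMESTEPS": 1e9,
}
print(config["ENV_NAME"])
```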

To launch the training:

python3 train_PPO_trXL.py

Or, for multi-GPU training (this will use all available GPUs):

python3 train_PPO_trXL_pmap.py

Results on Craftax

(figure: agent entering the sewer)

Without much hyperparameter search, with a budget of 1e9 timesteps, the normalized return (% max) reaches 18.3%, compared to 15.3% for PPO-RNN according to the Craftax paper (with one seed visiting the sewer). Note that a small error has since been fixed in the code: it caused the position-wise MLP part of the transformer to have a single layer instead of two, and to miss a layer norm. Performance appears similar with the fix, but to reproduce the exact graph below, see the corresponding commit.

The config used is the default config in train_PPO_trXL_pmap.py. The results, along with the trained parameters, can be found in results_craftax, and can be loaded with jnp.load(str(seed)+"_metrics.npy", allow_pickle=True).item().
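The load call above is the standard numpy pattern for a pickled dict saved as a .npy file. A self-contained sketch of the round trip, using np.load in place of jnp.load (they behave the same here) and dummy metric keys, since the actual file contents are not specified in this README:

```python
import numpy as np

# Save a dummy metrics dict the way the training script presumably does
# (the keys here are made up for illustration).
seed = 0
metrics = {"returned_episode_returns": np.zeros(10), "timesteps": np.arange(10)}
np.save(str(seed) + "_metrics.npy", metrics)

# The README's load call: the dict comes back as a 0-d object array,
# so .item() unwraps it into a plain Python dict.
loaded = np.load(str(seed) + "_metrics.npy", allow_pickle=True).item()
print(sorted(loaded.keys()))
```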

(figure: Craftax training curve, TransformerXL PPO)

Here are the achievement success rates across training for 1e9 steps. Notably, the success rate for "enter the gnomish mine" is much higher than reported in the Craftax paper: even PPO-RNN trained on 10e9 steps (10 times our budget) ends up not visiting the gnomish mines, while one of our seeds luckily visits the level after the gnomish mine, the sewer.

(figure: achievement success rates, 1e9 steps)

With a budget of 4e9 timesteps, the normalized return is 20.6%, with the agent visiting the third floor (the sewer) a decent amount of the time and achieving several advanced achievements. Neither visiting the third floor nor reaching any advanced achievement was accomplished by any of the baselines in the Craftax paper, even PPO-RNN with 10e9 environment interactions.

Here are the achievements success rates across training for 4e9 steps:

(figure: achievement success rates, 4e9 steps)

However, training for 8e9 steps did not lead to significant improvement, though we did not conduct much hyperparameter tuning.

Config parameters:

:warning: WINDOW_GRAD must divide NUM_STEPS.
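For example, a quick sanity check you could run on your own values (the numbers below are illustrative, not the repository defaults):

```python
NUM_STEPS = 128     # rollout length per update (illustrative value)
WINDOW_GRAD = 64    # gradient window length; must divide NUM_STEPS
assert NUM_STEPS % WINDOW_GRAD == 0, "WINDOW_GRAD must divide NUM_STEPS"
print(NUM_STEPS // WINDOW_GRAD, "gradient windows per rollout")
```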

For a continuous action space, you can follow the PureJaxRL example and replace the categorical distribution in the actor network with pi = distrax.MultivariateNormalDiag(actor_mean, jnp.exp(actor_logtstd)).
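What that one-liner does, sketched in plain numpy so it runs without distrax: a diagonal Gaussian policy head, where the network outputs a mean per action dimension and a learned log-std is exponentiated to get the per-dimension scale. This only mimics distrax.MultivariateNormalDiag for illustration; in the real actor you would use the distrax object directly.

```python
import numpy as np

def gaussian_policy_sample(actor_mean, actor_logstd, rng):
    """Sample an action and its log-prob from a diagonal Gaussian policy."""
    std = np.exp(actor_logstd)                      # per-dimension scale, as in jnp.exp(actor_logtstd)
    action = actor_mean + std * rng.standard_normal(actor_mean.shape)
    # log N(action; mean, diag(std^2)), summed over action dimensions
    log_prob = np.sum(
        -0.5 * ((action - actor_mean) / std) ** 2
        - np.log(std)
        - 0.5 * np.log(2 * np.pi)
    )
    return action, log_prob

rng = np.random.default_rng(0)
actor_mean = np.array([0.1, -0.3])   # network output for one state (illustrative)
actor_logstd = np.zeros(2)           # learned, state-independent parameter
action, logp = gaussian_policy_sample(actor_mean, actor_logstd, rng)
print(action.shape)  # (2,)
```

PPO then uses the log-prob of the taken action in its importance ratio, exactly as in the categorical case.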

Related Works

Next steps

Citation

@softwareversion{hamon:hal-04659863v1,
  TITLE = {{transformerXL\_PPO\_JAX}},
  AUTHOR = {Hamon, Gautier},
  URL = {https://inria.hal.science/hal-04659863},
  NOTE = {},
  YEAR = {2024},
  MONTH = Jul,
  REPOSITORY = {https://github.com/Reytuag/transformerXL_PPO_JAX},
  LICENSE = {MIT License},
  KEYWORDS = {Transformer ; Reinforcement Learning ; JAX},
  HAL_ID = {hal-04659863},
  HAL_VERSION = {v1},
}