<h1 align="center"> <a href="https://github.com/RobertTLange/gymnax/blob/main/docs/logo.png"> <img src="https://github.com/RobertTLange/gymnax/blob/main/docs/logo.png?raw=true" width="215" /></a><br> <b>Reinforcement Learning Environments in JAX 🌍</b><br> </h1> <p align="center"> <a href="https://pypi.python.org/pypi/gymnax"> <img src="https://img.shields.io/pypi/pyversions/gymnax.svg?style=flat-square" /></a> <a href= "https://badge.fury.io/py/gymnax"> <img src="https://badge.fury.io/py/gymnax.svg" /></a> <a href= "https://github.com/RobertTLange/gymnax/blob/master/LICENSE.md"> <img src="https://img.shields.io/badge/license-Apache2.0-blue.svg" /></a> <a href= "https://codecov.io/gh/RobertTLange/gymnax"> <img src="https://codecov.io/gh/RobertTLange/gymnax/branch/main/graph/badge.svg?token=OKKPDRIQJR" /></a> <a href= "https://github.com/psf/black"> <img src="https://img.shields.io/badge/code%20style-black-000000.svg" /></a> </p>

Are you fed up with slow CPU-based RL environment processes? Do you want to leverage massive vectorization for high-throughput RL experiments? gymnax brings the power of jit and vmap/pmap to the classic gym API. It supports a range of different environments including classic control, bsuite, MinAtar and a collection of classic/meta RL tasks. gymnax allows explicit functional control of environment settings (random seed or hyperparameters), which enables accelerated & parallelized rollouts for different configurations (e.g. for meta RL). By executing both environment and policy on the accelerator, it facilitates the Anakin sub-architecture proposed in the Podracer paper (Hessel et al., 2021) and highly distributed evolutionary optimization (using e.g. evosax). We provide training & checkpoints for both PPO & ES in gymnax-blines. Get started here 👉 Colab.

Basic gymnax API Usage 🍲

import jax
import gymnax

rng = jax.random.PRNGKey(0)
rng, key_reset, key_act, key_step = jax.random.split(rng, 4)

# Instantiate the environment & its settings.
env, env_params = gymnax.make("Pendulum-v1")

# Reset the environment.
obs, state = env.reset(key_reset, env_params)

# Sample a random action.
action = env.action_space(env_params).sample(key_act)

# Perform the step transition.
n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
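
Because reset and step are pure functions, they compose directly with jit and vmap/pmap. Below is a minimal sketch (the names vmap_reset, vmap_step, num_envs and vmap_keys are illustrative) that batches the transition above over several environments sharing the same settings; the same pattern extends to vmapping over a batch of env_params, which is what enables parallel rollouts for different configurations (e.g. for meta RL):

# Vectorize reset/step over a batch of PRNG keys (env_params are shared).
vmap_reset = jax.vmap(env.reset, in_axes=(0, None))
vmap_step = jax.vmap(env.step, in_axes=(0, 0, 0, None))

num_envs = 8
vmap_keys = jax.random.split(rng, num_envs)

# Batched reset, batched random actions, batched (jit-compiled) step.
obs, state = vmap_reset(vmap_keys, env_params)
actions = jax.vmap(env.action_space(env_params).sample)(vmap_keys)
n_obs, n_state, reward, done, _ = jax.jit(vmap_step)(vmap_keys, state, actions, env_params)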

Implemented Accelerated Environments 🏎️

| Environment Name | Reference | Source | 🤖 Ckpt (Return) | Secs/1M 🦶 A100 (2k 🌎) |
| --- | --- | --- | --- | --- |
| Acrobot-v1 | Brockman et al. (2016) | Click | PPO, ES (R: -80) | 0.07 |
| Pendulum-v1 | Brockman et al. (2016) | Click | PPO, ES (R: -130) | 0.07 |
| CartPole-v1 | Brockman et al. (2016) | Click | PPO, ES (R: 500) | 0.05 |
| MountainCar-v0 | Brockman et al. (2016) | Click | PPO, ES (R: -118) | 0.07 |
| MountainCarContinuous-v0 | Brockman et al. (2016) | Click | PPO, ES (R: 92) | 0.09 |
| Asterix-MinAtar | Young & Tian (2019) | Click | PPO (R: 15) | 0.92 |
| Breakout-MinAtar | Young & Tian (2019) | Click | PPO (R: 28) | 0.19 |
| Freeway-MinAtar | Young & Tian (2019) | Click | PPO (R: 58) | 0.87 |
| SpaceInvaders-MinAtar | Young & Tian (2019) | Click | PPO (R: 131) | 0.33 |
| Catch-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 1) | 0.15 |
| DeepSea-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 0) | 0.22 |
| MemoryChain-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 0.1) | 0.13 |
| UmbrellaChain-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 1) | 0.08 |
| DiscountingChain-bsuite | Osband et al. (2019) | Click | PPO, ES (R: 1.1) | 0.06 |
| MNISTBandit-bsuite | Osband et al. (2019) | Click | - | - |
| SimpleBandit-bsuite | Osband et al. (2019) | Click | - | - |
| FourRooms-misc | Sutton et al. (1999) | Click | PPO, ES (R: 1) | 0.07 |
| MetaMaze-misc | Micconi et al. (2020) | Click | ES (R: 32) | 0.09 |
| PointRobot-misc | Dorfman et al. (2021) | Click | ES (R: 10) | 0.08 |
| BernoulliBandit-misc | Wang et al. (2017) | Click | ES (R: 90) | 0.08 |
| GaussianBandit-misc | Lange & Sprekeler (2022) | Click | ES (R: 0) | 0.07 |
| Reacher-misc | Lenton et al. (2021) | Click | - | - |
| Swimmer-misc | Lenton et al. (2021) | Click | - | - |
| Pong-misc | Kirsch (2018) | Click | - | - |

* All displayed speeds are estimated for 1M step transitions (random policy) on an NVIDIA A100 GPU using jit-compiled episode rollouts with 2000 environment workers. For more detailed speed comparisons across accelerators (CPU, RTX 2080Ti) and with MLP policies, please refer to the gymnax-blines documentation.
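
For reference, these timings come from episode rollouts that stay entirely on the accelerator. Below is a minimal sketch of such a jit-compiled random-policy rollout using jax.lax.scan; it assumes the env, env_params and rng objects from the usage snippet above, and rollout/policy_step are illustrative names rather than part of the gymnax API:

import functools

# Sketch: roll out a fixed number of random-policy steps fully inside jit.
# gymnax environments reset automatically on termination, so a fixed-length
# scan can run across episode boundaries.
@functools.partial(jax.jit, static_argnums=(2,))
def rollout(rng, env_params, num_steps):
    rng, key_reset = jax.random.split(rng)
    obs, state = env.reset(key_reset, env_params)

    def policy_step(carry, _):
        rng, obs, state = carry
        rng, key_act, key_step = jax.random.split(rng, 3)
        action = env.action_space(env_params).sample(key_act)
        n_obs, n_state, reward, done, _ = env.step(key_step, state, action, env_params)
        return (rng, n_obs, n_state), (reward, done)

    _, (rewards, dones) = jax.lax.scan(policy_step, (rng, obs, state), None, num_steps)
    return rewards, dones

rewards, dones = rollout(rng, env_params, 1000)

Wrapping the rollout in jax.vmap over a batch of PRNG keys then yields the many-worker setup used for the timings above.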

Installation ⏳

The latest gymnax release can be installed directly from PyPI:

pip install gymnax

If you want to get the most recent commit, please install directly from the repository:

pip install git+https://github.com/RobertTLange/gymnax.git@main

To run JAX on your accelerators (GPU/TPU), please refer to the JAX documentation for installation details.
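
For example, at the time of writing the JAX documentation lists a CUDA 12 wheel that can be installed as follows (the exact extra depends on your accelerator and may change, so please double-check the linked instructions):

pip install -U "jax[cuda12]"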

Examples 📖

Key Selling Points 💵

Resources & Other Great Tools 📝

Acknowledgements & Citing gymnax ✏️

If you use gymnax in your research, please cite it as follows:

@software{gymnax2022github,
  author = {Robert Tjarko Lange},
  title = {{gymnax}: A {JAX}-based Reinforcement Learning Environment Library},
  url = {http://github.com/RobertTLange/gymnax},
  version = {0.0.4},
  year = {2022},
}

We acknowledge financial support by the Google TRC and the Deutsche Forschungsgemeinschaft (DFG, German Research Foundation) under Germany's Excellence Strategy - EXC 2002/1 "Science of Intelligence" - project number 390523135.

Development 👷

You can run the test suite via python -m pytest -vv --all. If you find a bug or are missing your favourite feature, feel free to create an issue and/or start contributing 🤗.