

<div align="center"> <img width=150px src="https://github.com/epignatelli/navix/assets/26899347/4168c100-f0e6-4bae-9680-2c1a82bba8a4" alt="logo"></img>

NAVIX: minigrid in JAX


Quickstart | Install | Performance | Examples | Docs | The JAX ecosystem | Contribute | Cite

</div>

What is NAVIX?

NAVIX is a JAX-powered reimplementation of MiniGrid. Experiments that took <ins>1 week</ins> now take <ins>15 minutes</ins>.

A 200 000× speedup over MiniGrid and a throughput of 670 million steps/s are not just speed improvements. They open a new paradigm, giving access to experiments that were previously impossible, e.g., those that would have taken years to run.

It changes the game.
Check out NAVIX's performance in more detail, and the documentation for more information.

Key features:

The library is in active development, and we are working on adding more environments and features. If you want to join the development and contribute, please open a discussion and let's have a chat!

Installation

Install JAX

Follow the official installation guide for your OS and preferred accelerator: https://github.com/google/jax#installation.

Install NAVIX

```bash
pip install navix
```

Or, for the latest version from source:

```bash
pip install git+https://github.com/epignatelli/navix
```
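Before running NAVIX, it can help to confirm which backend JAX picked up. The snippet below is a minimal sanity check that only calls `jax.default_backend()` and `jax.devices()`; on a machine without a supported GPU or TPU it will simply report the CPU.

```python
import jax

backend = jax.default_backend()  # typically "cpu", "gpu" or "tpu"
devices = jax.devices()          # devices visible to the default backend
print(f"JAX {jax.__version__} is using {backend}: {devices}")
```

If this prints only CPU devices on a machine with an accelerator, the accelerator-specific JAX install likely did not succeed.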

Performance

NAVIX improves on MiniGrid in both execution speed and throughput, making it possible to run more than 2048 PPO agents in parallel almost 10 times faster than a single PPO agent in the original MiniGrid.

[Figure: speedup_env]

In batch mode, NAVIX performs 32 768 × 1M / 49 s = 668 734 693.88 steps per second (∼670 million steps/s), while the original MiniGrid implementation performs 1M / 318.01 s = 3 144.56 steps per second. This is a speedup of over 200 000×.

[Figure: throughput_ppo]
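The throughput arithmetic can be reproduced in a few lines of Python. This sketch takes the quoted NAVIX figure (668 734 693.88 steps/s) and the MiniGrid timing (1M steps in 318.01 s) as given, and also recovers the number of parallel environments implied by the NAVIX figure at 1M steps per environment in 49 s. `steps_per_second` is a hypothetical helper for illustration, not part of NAVIX.

```python
def steps_per_second(num_envs: int, steps_per_env: int, seconds: float) -> float:
    """Aggregate throughput: total environment steps over wall-clock seconds."""
    return num_envs * steps_per_env / seconds


STEPS_PER_ENV = 1_000_000

navix_sps = 668_734_693.88  # NAVIX batch-mode figure quoted above
minigrid_sps = steps_per_second(1, STEPS_PER_ENV, 318.01)  # single MiniGrid env

speedup = navix_sps / minigrid_sps
# Number of parallel environments that reproduces the NAVIX figure in 49 s
implied_envs = navix_sps * 49 / STEPS_PER_ENV

print(f"MiniGrid throughput: {minigrid_sps:,.2f} steps/s")
print(f"Speedup: {speedup:,.0f}x")
print(f"Implied batch size: {implied_envs:,.0f} environments")
```

Running it gives about 3 144.56 steps/s for MiniGrid, a speedup of roughly 212 000×, and an implied batch of 32 768 (2^15) environments.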

Examples

You can view a full set of examples here (more coming), but here are the most common use cases.

Compiling a collection step

```python
import jax
import jax.numpy as jnp
import navix as nx

N_TIMESTEPS = 100  # example value: number of steps per rollout


def run(seed):
    env = nx.make('MiniGrid-Empty-8x8-v0')  # Create the environment
    key = jax.random.PRNGKey(seed)
    reset_key, action_key = jax.random.split(key)
    timestep = env.reset(reset_key)
    actions = jax.random.randint(action_key, (N_TIMESTEPS,), 0, env.action_space.n)

    def body_fun(timestep, action):
        timestep = env.step(timestep, action)  # Update the environment state
        return timestep, ()

    # Run the whole rollout inside XLA and keep only the final timestep
    return jax.lax.scan(body_fun, timestep, actions)[0]


# Compile 1000 parallel rollouts (one per seed) into a single XLA program
final_timestep = jax.jit(jax.vmap(run))(jnp.arange(1000))
```
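The example above combines three transformations: `jax.lax.scan` runs the time loop inside XLA, `jax.vmap` batches the rollout over seeds, and `jax.jit` compiles the result into one program. The same pattern is sketched below on a hypothetical toy environment (an agent on a 1-D corridor of `LENGTH` cells), written in plain JAX so it runs without NAVIX; `toy_reset` and `toy_step` are illustrative stand-ins, not NAVIX functions.

```python
import jax
import jax.numpy as jnp

LENGTH = 8        # hypothetical corridor length; the goal is the last cell
N_TIMESTEPS = 16  # fixed horizon so the loop has a static length


def toy_reset(key):
    # Start at a random cell in the left half of the corridor
    return jax.random.randint(key, (), 0, LENGTH // 2)


def toy_step(pos, action):
    # Action 0 moves left, action 1 moves right; positions are clipped to the corridor
    pos = jnp.clip(pos + 2 * action - 1, 0, LENGTH - 1)
    reward = (pos == LENGTH - 1).astype(jnp.float32)
    return pos, reward


def run(seed):
    key = jax.random.PRNGKey(seed)
    reset_key, action_key = jax.random.split(key)
    pos = toy_reset(reset_key)
    actions = jax.random.randint(action_key, (N_TIMESTEPS,), 0, 2)

    def body_fun(pos, action):
        pos, reward = toy_step(pos, action)
        return pos, reward  # carry the state forward, stack per-step rewards

    final_pos, rewards = jax.lax.scan(body_fun, pos, actions)
    return final_pos, rewards.sum()


# One compiled call runs 1000 independent rollouts in parallel
final_pos, returns = jax.jit(jax.vmap(run))(jnp.arange(1000))
print(final_pos.shape, returns.shape)
```

The state lives in the `scan` carry rather than in a mutable object, which is what lets `vmap` batch it and `jit` compile it.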

Compiling a full training run

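```python
import jax
import jax.numpy as jnp
import navix as nx

N_STEPS = 100  # fixed horizon: a Python `while not done` loop cannot be traced by jit


def policy(params, observation, key, num_actions):
    """Hypothetical policy: replace with your own network.

    Samples a uniformly random action so the example runs end to end.
    """
    return jax.random.randint(key, (), 0, num_actions)


def make_train_step(env_id, num_envs):
    env = nx.make(env_id)  # build the environment once; batch over seeds, not env objects
    num_actions = env.action_space.n

    def run_episode(params, key):
        """Rolls out one episode and returns its total reward."""
        reset_key, policy_key = jax.random.split(key)
        timestep = env.reset(reset_key)

        def body_fun(carry, step_key):
            timestep, total_reward = carry
            action = policy(params, timestep.observation, step_key, num_actions)
            timestep = env.step(timestep, action)
            return (timestep, total_reward + timestep.reward), None

        step_keys = jax.random.split(policy_key, N_STEPS)
        init = (timestep, jnp.zeros((), dtype=timestep.reward.dtype))
        (_, total_reward), _ = jax.lax.scan(body_fun, init, step_keys)
        return total_reward

    # Compile the entire batched rollout (and update) with XLA
    @jax.jit
    def train_step(params, key):
        keys = jax.random.split(key, num_envs)
        rewards = jax.vmap(run_episode, in_axes=(None, 0))(params, keys)
        # ... Update the policy parameters based on rewards ...
        return params, rewards

    return train_step


# Start the training
train_step = make_train_step('MiniGrid-MultiRoom-N2-S4-v0', num_envs=100)
params = None  # hypothetical placeholder for your policy parameters
key = jax.random.PRNGKey(0)
for _ in range(10):
    key, subkey = jax.random.split(key)
    params, rewards = train_step(params, subkey)
```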
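A Python `while not done` loop cannot be compiled, because `done` is a traced value under `jax.jit`. When an episode must stop at a data-dependent point, `jax.lax.while_loop` is the compiled counterpart. The sketch below uses a hypothetical 1-D corridor (not a NAVIX environment) with a random policy and an assumed `MAX_STEPS` cap so every episode terminates.

```python
import jax
import jax.numpy as jnp

LENGTH = 8       # hypothetical corridor; reaching the last cell ends the episode
MAX_STEPS = 100  # safety cap so every episode terminates


def run_episode(key):
    """Steps a random policy until the goal is reached or MAX_STEPS elapse."""

    def cond_fun(carry):
        pos, t, total_reward, key = carry
        done = pos == LENGTH - 1
        return jnp.logical_and(~done, t < MAX_STEPS)

    def body_fun(carry):
        pos, t, total_reward, key = carry
        key, subkey = jax.random.split(key)
        action = jax.random.randint(subkey, (), 0, 2)  # 0 = left, 1 = right
        pos = jnp.clip(pos + 2 * action - 1, 0, LENGTH - 1)
        total_reward = total_reward + (pos == LENGTH - 1).astype(jnp.float32)
        return pos, t + 1, total_reward, key

    init = (jnp.int32(0), jnp.int32(0), jnp.float32(0.0), key)
    pos, t, total_reward, _ = jax.lax.while_loop(cond_fun, body_fun, init)
    return total_reward, t


keys = jax.random.split(jax.random.PRNGKey(0), 100)
returns, lengths = jax.jit(jax.vmap(run_episode))(keys)
```

Under `vmap`, a batched `while_loop` keeps iterating until every episode in the batch has finished, leaving finished episodes unchanged; this is one reason fixed-horizon `scan` rollouts are often preferred for throughput.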

Backpropagation through the environment

```python
import jax
import jax.numpy as jnp
import navix as nx
from jax import grad
from flax import linen as nn


class Model(nn.Module):
    """Predicts the next observation from the current one. Replace with your NN."""

    @nn.compact
    def __call__(self, x):
        x = x.astype(jnp.float32)
        # A single linear layer over the flattened observation, for illustration
        y = nn.Dense(features=x.size)(x.reshape(-1))
        return y.reshape(x.shape)


model = Model()
env = nx.environments.Room(16, 16, 8)


def loss(params, timestep):
    action = jnp.asarray(0)
    pred_obs = model.apply(params, timestep.observation)
    timestep = env.step(timestep, action)
    target = timestep.observation.astype(jnp.float32)
    return jnp.square(target - pred_obs).mean()


key = jax.random.PRNGKey(0)
timestep = env.reset(key)
params = model.init(key, timestep.observation)

gradients = grad(loss)(params, timestep)
```
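Because the environment step is an ordinary JAX function, `jax.grad` can in principle differentiate through it wherever its outputs depend smoothly on its inputs. Grid-world dynamics such as MiniGrid's are mostly discrete, so the sketch below uses a hypothetical continuous point mass instead (not a NAVIX environment): it unrolls the dynamics with `scan` and runs gradient descent on the action sequence. `GOAL`, `DT` and `step` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

GOAL = jnp.array([1.0, 2.0])  # hypothetical target position
DT = 0.1                      # hypothetical integration step
LEARNING_RATE = 1.0


def step(pos, action):
    # Point-mass dynamics: the action is a velocity applied for DT seconds
    return pos + DT * action


def rollout_loss(actions, start):
    # Unroll the dynamics with scan, then score the final position
    final_pos, _ = jax.lax.scan(lambda p, a: (step(p, a), None), start, actions)
    return jnp.sum((final_pos - GOAL) ** 2)


start = jnp.zeros(2)
actions = jnp.zeros((10, 2))  # 10 steps of 2-D velocity

# Plain gradient descent on the actions, differentiating through every step
loss_and_grad = jax.jit(jax.value_and_grad(rollout_loss))
for _ in range(100):
    loss, grads = loss_and_grad(actions, start)
    actions = actions - LEARNING_RATE * grads
```

After the loop the unrolled trajectory ends at `GOAL`, showing that gradients flowed through all ten environment steps.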

JAX ecosystem for RL

NAVIX is not alone: it is part of an ecosystem of JAX-powered modules for RL. Check out the following projects:

Join Us!

NAVIX is actively developed. If you'd like to contribute to this open-source project, we welcome your involvement! Start a discussion or open a pull request.

Please consider starring the project if you like NAVIX!

Cite us, please!

If you use NAVIX please cite it as:

```bibtex
@article{pignatelli2024navix,
  title={NAVIX: Scaling MiniGrid Environments with JAX},
  author={Pignatelli, Eduardo and Liesen, Jarek and Lange, Robert Tjarko and Lu, Chris and Castro, Pablo Samuel and Toni, Laura},
  journal={arXiv preprint arXiv:2407.19396},
  year={2024}
}
```