<h1 align="center">JaxIRL</h1> <p align="center"> <img src="https://img.shields.io/badge/python-3.8_%7C_3.9-blue" /> <a href= "https://github.com/psf/black"> <img src="https://img.shields.io/badge/code%20style-black-000000.svg" /></a> <a href= "https://github.com/FLAIROx/jaxirl/blob/main/LICENSE"> <img src="https://img.shields.io/badge/license-Apache2.0-blue.svg" /></a> </p>

Installation | Setup | Algorithms | Citation

Inverse Reinforcement Learning in JAX

This repository contains JAX implementations of algorithms for inverse reinforcement learning (IRL). Inverse RL is an online approach to imitation learning in which we try to extract a reward function that makes the expert optimal. Unlike behavioural cloning, IRL doesn't suffer from compounding errors and doesn't need expert actions to train (only example trajectories of states). Depending on the environment and hyperparameters, our implementation is about 🔥 100x 🔥 faster than standard IRL implementations in PyTorch (e.g. 3.5 minutes to train a single hopper agent ⚡). By running multiple agents in parallel, you can be even faster! (e.g. 200 walker agents can be trained in ~400 minutes on 1 GPU! That's 2 minutes per agent ⚡⚡).

<div class="collage"> <div class="column" align="center"> <div class="row" align="center"> <img src="https://github.com/FLAIROx/jaxirl/blob/main/plots/hopper.png" alt="Hopper" width="40%"> <img src="https://github.com/FLAIROx/jaxirl/blob/main/plots/walker2d.png" alt="walker" width="40%"> </div> <div class="row" align="center"> <img src="https://github.com/FLAIROx/jaxirl/blob/main/plots/ant.png" alt="ant" width="40%"> <img src="https://github.com/FLAIROx/jaxirl/blob/main/plots/halfcheetah.png" alt="halfcheetah" width="40%"> </div> </div> </div>

A game-theoretic perspective on IRL

<img src="https://github.com/FLAIROx/jaxirl/blob/main/plots/irl.png" align="left" alt="IRL" width="10%">

IRL is commonly framed as a two-player zero-sum game between a policy player and a reward-function player. Intuitively, the reward-function player tries to pick out differences between the current learner policy and the expert, while the policy player tries to maximise this reward function and thereby move closer to expert behaviour. This setup is effectively a GAN in trajectory space, where the reward player acts as the discriminator and the policy player as the generator.
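Concretely, moment-matching IRL (the variant implemented here, see Algorithms below) is usually written as the min-max problem below. This is the standard textbook formulation rather than a line-by-line description of the code; the symbols are ours: $\pi_E$ is the expert policy, $\pi$ the learner, and $\mathcal{R}$ a class of reward functions.

$$
\min_{\pi} \; \max_{r \in \mathcal{R}} \;\; \mathbb{E}_{\pi_E}\!\left[\sum_{t} r(s_t, a_t)\right] \;-\; \mathbb{E}_{\pi}\!\left[\sum_{t} r(s_t, a_t)\right]
$$

The inner maximisation corresponds to the reward/discriminator update, and the outer minimisation to the RL step on the learner policy.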

<br/><br/><br/>

Why JAX?

JAX is a game-changer in the world of machine learning, letting researchers and developers train models with remarkable efficiency and scalability by compiling and vectorising entire training pipelines.

All our code can be used with jit, vmap, pmap and scan inside other pipelines. This allows you to compile a whole training run end-to-end, train many agents in parallel on a single GPU, and embed these algorithms directly inside larger JAX workflows (see the sketch below).
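As a rough illustration of the pattern (a minimal sketch, not the repository's actual training code), here is how a jit-compiled training function can be vmapped over random seeds so that many agents train in parallel on one device:

```python
import jax
import jax.numpy as jnp

def train_agent(rng):
    # Stand-in for a full training loop (e.g. PPO inside the IRL outer loop);
    # a dummy computation is used here just to show the jit/vmap pattern.
    params = jax.random.normal(rng, (64,))
    return jnp.tanh(params).sum()

# Compile once, then vectorise over 200 seeds -> 200 independent training runs.
train_many = jax.jit(jax.vmap(train_agent))
seeds = jax.random.split(jax.random.PRNGKey(0), 200)
results = train_many(seeds)
print(results.shape)  # (200,)
```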

Running Experiments

The experts are already provided, but to retrain them, simply delete the corresponding expert file and they will be retrained automatically. The default configs for the experts are in jaxirl/configs/inner_training_configs.py. To change the default configs for IRL training, edit jaxirl/configs/outer_training_configs.py.
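If you want to see which default configurations are defined before editing anything, one quick way (not part of the repository's documented workflow, and assuming the package has been installed as described below) is to inspect the config module directly:

```python
# List the public names defined in the outer-loop (IRL) config module.
# Only the module path given above is assumed; no specific config names are.
import jaxirl.configs.outer_training_configs as outer_cfgs

print([name for name in dir(outer_cfgs) if not name.startswith("_")])
```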

To train an IRL agent, run:

python jaxirl/irl/main.py --loss loss_type --env env_name

This package supports training via behavioural cloning (loss_type = BC), IRL (loss_type = IRL), and plain RL (loss_type = NONE).

We support the Brax environments shown above (halfcheetah, hopper, walker2d and ant), as well as classic control environments.
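For example, to train an IRL agent on one of the Brax environments above (the value passed to --env is assumed to follow the environment naming used in the plots):

python jaxirl/irl/main.py --loss IRL --env halfcheetah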

Setup

The high-level structure of this repository is as follows:

├── jaxirl  # package folder
│   ├── configs  # standard configs for inner and outer loop
│   ├── envs  # extra envs
│   ├── irl  # main scripts that implement imitation learning and IRL algorithms
│   │   ├── bc.py  # standard behavioural cloning, called when loss_type = BC
│   │   ├── irl.py  # basic IRL algorithm, called when loss_type = IRL
│   │   ├── gail_discriminator.py  # used by irl.py to implement the IRL algorithm
│   │   ├── main.py  # main script to call to execute all algorithms
│   │   ├── rl.py  # trains a basic RL agent, called when loss_type = NONE
│   ├── training  # generated expert demos
│   │   ├── ppo_v2_cont_irl.py  # PPO implementation for continuous-action envs
│   │   ├── ppo_v2_irl.py  # PPO implementation for discrete-action envs
│   │   ├── supervised.py  # standard supervised training implementation
│   │   ├── wrappers.py  # utility wrappers for training
│   ├── utils  # utility functions
├── experts  # expert policies
├── experts_test  # expert policies for the test version of the environments

Install

conda create -n jaxirl python=3.10.8
conda activate jaxirl
pip install -r requirements.txt
pip install -e .
export PYTHONPATH=jaxirl:$PYTHONPATH

> [!IMPORTANT]
> All scripts should be run from under `jaxirl/`.
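As an optional sanity check (not part of the repository's instructions), you can confirm that JAX is installed and sees your accelerator:

```python
# Print the devices JAX can see; expect a GPU/TPU entry if one is available,
# otherwise CPU.
import jax

print(jax.devices())
```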

Algorithms

Our IRL implementation is the moment-matching version. It includes implementation tricks that make learning more stable, such as decaying the discriminator and learner learning rates and applying gradient penalties to the discriminator.
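For instance, a gradient penalty of the kind mentioned above typically regularises the norm of the discriminator's gradient with respect to its inputs. The sketch below is a generic version of that idea (the function signature and shapes are our assumptions, not the repository's exact code):

```python
import jax
import jax.numpy as jnp

def gradient_penalty(disc_fn, params, expert_obs, learner_obs, rng):
    """Penalise ||d disc / d obs|| on points interpolated between expert and
    learner observations (WGAN-GP style). Assumes disc_fn(params, obs) -> scalar
    and obs batches of shape (batch, obs_dim)."""
    eps = jax.random.uniform(rng, (expert_obs.shape[0], 1))
    interp = eps * expert_obs + (1.0 - eps) * learner_obs
    # Per-example gradient of the discriminator output w.r.t. its input.
    grads = jax.vmap(jax.grad(lambda o: disc_fn(params, o)))(interp)
    grad_norms = jnp.linalg.norm(grads, axis=-1)
    return jnp.mean((grad_norms - 1.0) ** 2)
```

Whether the target norm is 0 or 1, and where the penalty is applied, is a design choice; this sketch uses the common WGAN-GP form.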

Reproduce Results

Simply run

python3 jaxirl/irl/main.py --env env_name --loss IRL -sd 1

and the default parameters in outer_training_configs.py and the trained experts in experts/ will be used.

Citation

If you find this code useful in your research, please cite:

@misc{sapora2024evil,
      title={EvIL: Evolution Strategies for Generalisable Imitation Learning}, 
      author={Silvia Sapora and Gokul Swamy and Chris Lu and Yee Whye Teh and Jakob Nicolaus Foerster},
      year={2024},
}

See Also 🙌

Our work reuses code, tricks and implementation details from the following libraries; we encourage you to take a look!