VariBAD

Code for the paper "VariBAD: A Very Good Method for Bayes-Adaptive Deep RL via Meta-Learning" by Luisa Zintgraf, Kyriacos Shiarlis, Maximilian Igl, Sebastian Schulze, Yarin Gal, Katja Hofmann, and Shimon Whiteson, published at ICLR 2020.

@inproceedings{zintgraf2020varibad,
  title={VariBAD: A Very Good Method for Bayes-Adaptive Deep RL via Meta-Learning},
  author={Zintgraf, Luisa and Shiarlis, Kyriacos and Igl, Maximilian and Schulze, Sebastian and Gal, Yarin and Hofmann, Katja and Whiteson, Shimon},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2020}}

! Important !

If you use this code with your own environments, do not use np.random inside them (e.g. to generate the tasks): it is not thread safe and may produce duplicate samples across threads. Use Python's built-in random module instead. For an example see here.
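As a rough illustration of the advice above, the following sketch samples tasks with Python's built-in random module rather than np.random. The names sample_goal and sample_tasks, the grid size, and the idea of a task being a 2D goal cell are assumptions made for this example and are not taken from the repository.

```python
import random


def sample_goal(grid_size=5, rng=random):
    """Sample a hypothetical 2D goal cell with Python's built-in random module.

    `rng` can be the `random` module itself or a `random.Random` instance.
    """
    return (rng.randint(0, grid_size - 1), rng.randint(0, grid_size - 1))


def sample_tasks(num_tasks, grid_size=5, seed=None):
    """Return a list of goal cells drawn from a dedicated generator.

    A separate random.Random instance keeps this environment's stream
    independent of any other code that uses the global generator.
    """
    rng = random.Random(seed)
    return [sample_goal(grid_size, rng) for _ in range(num_tasks)]


if __name__ == "__main__":
    print(sample_tasks(num_tasks=4, grid_size=5, seed=0))
```

Giving each environment its own random.Random instance is one possible design choice; calling the module-level functions directly, as in the default of sample_goal, would also avoid np.random.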

Requirements

We use PyTorch for this code, and log results using TensorboardX.

The main requirements can be found in requirements.txt.

For the MuJoCo experiments, you need to install MuJoCo. Make sure you have the right MuJoCo version:

For mujoco131, use: gym==0.9.1 gym[mujoco]==0.9.1 mujoco-py==0.5.7

Overview

The main training loop for VariBAD is in metalearner.py, the models are in models/, the VAE set-up and losses are in vae.py, and the RL algorithms are in algorithms/.

There's quite a bit of documentation in the respective scripts so have a look there for details.

Running an experiment

To evaluate variBAD on the gridworld from the paper, run

python main.py --env-type gridworld_varibad

which will use hyperparameters from config/gridworld/args_grid_varibad.py.

To run variBAD on the MuJoCo experiments use:

python main.py --env-type cheetah_dir_varibad
python main.py --env-type cheetah_vel_varibad
python main.py --env-type ant_dir_varibad
python main.py --env-type walker_varibad

You can also run RL2 and the Oracle: replace varibad in the commands above with the respective string. See main.py for all options.

By default, results are saved in ./logs. To use a different directory, pass --results_log_dir /path/to/dir.

The default configs are in the config/ folder. You can overwrite any default hyperparameters using command line arguments.
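The general pattern of config defaults being overridden from the command line can be sketched with argparse as below. This is a standalone toy parser, not the repository's config code: the get_args name and the --lr flag with its default value are hypothetical, and only --results_log_dir and its ./logs default come from the text above.

```python
import argparse


def get_args(cli_args):
    """Build a toy config: defaults live here, the command line overrides them."""
    parser = argparse.ArgumentParser()
    # Flag and default mentioned in this README.
    parser.add_argument('--results_log_dir', default='./logs')
    # Hypothetical hyperparameter, used only to illustrate overriding.
    parser.add_argument('--lr', type=float, default=7e-4)
    return parser.parse_args(cli_args)


# No command line arguments: every value falls back to its default.
defaults = get_args([])

# Arguments given on the command line replace the corresponding defaults.
overridden = get_args(['--results_log_dir', '/tmp/varibad_logs', '--lr', '0.001'])

if __name__ == "__main__":
    print(defaults)
    print(overridden)
```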

Results are written to tensorboard event files, and some visualisations are printed periodically during training.

Configs

Some comments on the flags in the config files:

Results

The MuJoCo results (smoothed learning curves) and a script to plot them can be found here.

Comments