Mastering Memory Tasks with World Models


M. R. Samsami*, A. Zholus*, J. Rajendran, S. Chandar. Mastering Memory Tasks with World Models. ICLR 2024, top‑1.2% (oral) <br /> [Project Website]       [Paper]       [Openreview]       [X Thread]


This repo contains the official implementation of the Recall to Imagine (R2I) algorithm, introduced in the paper Mastering Memory Tasks with World Models. R2I is a generalist, computationally efficient, model-based agent focusing on memory-intensive reinforcement learning (RL) tasks. It stands out by demonstrating superhuman performance in a complex memory domain, Memory Maze. The codebase, powered by JAX, offers an efficient framework for RL development that incorporates state space models.

<p float="center"> <img src="imgs/r2i.png" alt="Teaser" width="800"/> </p>

Prerequisites

Before diving into the installation and running experiments, ensure you meet the following prerequisites:

| Requirement | Specification |
| --- | --- |
| Hardware | GPU with CUDA support |
| Software | Docker or Singularity for containerized environments, conda |
| Python Version | 3.8 |
| Storage | At least 130GB for the Memory Maze experiment's replay buffer |

Installation

Conda

We provide a self-contained conda environment for running experiments. To set up the environment, run:

conda create -n recall2imagine python=3.8
conda activate recall2imagine
cd env && conda env update -f conda_env.yaml
pip install \
    "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html \
    dm-haiku \
    flax \
    optax==0.1.5

Although the conda environment should work, the recommended way of setting up the environment is through Docker or Singularity containers. See below for details.
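After installing, it can help to confirm that JAX actually sees a GPU before launching a long run. The snippet below is a minimal sketch, not part of this repo: the helper name `describe_jax_backend` is hypothetical, and it only relies on `jax.devices()` and the `platform` attribute of each device.

```python
def describe_jax_backend():
    """Return a one-line summary of the devices JAX can see."""
    try:
        import jax  # imported lazily so the check reports a missing install
    except ImportError:
        return "jax is not installed in this environment"
    devices = jax.devices()
    # Collect the distinct platform names reported by the visible devices.
    platforms = sorted({d.platform for d in devices})
    return f"{len(devices)} device(s) on platform(s): {', '.join(platforms)}"


if __name__ == "__main__":
    print(describe_jax_backend())
```

If the summary only mentions a CPU platform, the CUDA build of JAX was likely not picked up, and the container route may be the easier fix.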

Docker & Singularity

The most reliable way to reproduce the results from the paper is through containers. To build the Docker image, run:

cd env && docker build -t r2i .

This step takes several minutes to complete. Alternatively, you can pull our pre-built image:

docker pull artemzholus/r2i && docker tag artemzholus/r2i r2i

If you plan to run the code inside a Singularity container, convert the Docker image to a Singularity image. Note that the conversion creates many temporary files in your home directory. To avoid that, point the Singularity cache to another location:

export SINGULARITY_CACHEDIR=<STORAGE_FOLDER>

If you're not sure whether you need this, just skip this step.

To convert the image, run:

singularity pull docker://artemzholus/r2i

Running experiments with containers

To run experiments with Docker, launch an interactive container:

docker run --rm -it --network host -v $(pwd):/code -w /code r2i bash

This opens an interactive shell inside the container. Run any Python training script from that shell.

For singularity:

singularity shell --nv <IMAGE_PATH.sif>

where <IMAGE_PATH.sif> is the name of the image file downloaded by singularity pull.

Reproducing results from the paper

The code was designed with reproducibility in mind.

Memory Maze

TL;DR: use the following command:

python recall2imagine/train.py \
    --configs mmaze --task gym_memory_maze:MemoryMaze-9x9-v0 \
    --wdb_name memory_maze_9x9 \
    --logdir ./logs_memory_maze_9x9

By default, the script uses 1 GPU. To ensure that you match the training speed reported in our paper, consider using two GPUs (this does not affect the sample efficiency):

python recall2imagine/train.py \
    --configs mmaze --task gym_memory_maze:MemoryMaze-9x9-v0 \
    --wdb_name memory_maze_9x9 \
    --jax.train_devices 0 1 \
    --jax.policy_devices 0 \
    --logdir ./logs_memory_maze_9x9

NOTE: By default, the training script uses a disk-based replay buffer. In Memory Maze, the default replay buffer size is 10M image-based RL steps, which takes about 130GB of disk space. Make sure you have enough free space before running, since the script allocates the whole file at the beginning of training.
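A rough back-of-the-envelope check shows where a figure of this size can come from, and how to verify free space up front. This is a sketch under stated assumptions, not code from this repo: it assumes each stored step is dominated by one 64x64 RGB `uint8` image, it treats 130GB as decimal gigabytes, and the helper names are hypothetical.

```python
import shutil

# Assumption: one stored step is dominated by a 64x64 RGB uint8 image.
IMAGE_BYTES = 64 * 64 * 3
# Default Memory Maze replay buffer size stated in the README: 10M steps.
REPLAY_STEPS = 10_000_000


def estimate_image_bytes(steps=REPLAY_STEPS, bytes_per_image=IMAGE_BYTES):
    """Bytes needed for the image observations alone."""
    return steps * bytes_per_image


def has_enough_space(path, required_bytes):
    """True if the filesystem holding `path` has at least `required_bytes` free."""
    return shutil.disk_usage(path).free >= required_bytes


if __name__ == "__main__":
    print(f"images alone: {estimate_image_bytes() / 1e9:.1f} GB")
    # 130GB from the README, interpreted here as decimal gigabytes.
    print("enough free space:", has_enough_space(".", 130 * 10**9))
```

Under these assumptions the images alone account for roughly 123 GB, and the remainder would be other per-step fields and bookkeeping, which is consistent with the "about 130GB" figure.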

<img src="imgs/mmaze.jpg" alt="memory-maze" width="800"/>

Since this experiment can take a long time, users of large compute clusters may run into a limit on the maximum job duration. For exactly that scenario, we designed a new replay buffer that makes restarting very long RL experiments efficient. The idea is to keep two versions of the replay buffer: one in fast but temporary storage, the other in slower but persistent storage. Use the following pattern to continue your experiment after an interruption:

python recall2imagine/train.py \
    --configs mmaze --task gym_memory_maze:MemoryMaze-9x9-v0 \
    --wdb_name memory_maze_9x9 \
    --jax.train_devices 0 1 \
    --jax.policy_devices 0 \
    --use_lfs True --lfs_dir <LONGTERM_STORAGE> \
    --logdir <TMP_STORAGE>

Use this command for the first run and for every subsequent restart; the script detects and restores the saved state. This buffer is fault tolerant, and it is how we ran the original experiments in the paper.
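The two-storage idea can be illustrated with a small, self-contained sketch. This is not the replay buffer implemented in this repo: the `*.chunk` file layout and the function names are hypothetical, chosen only to show syncing from fast storage to long-term storage and restoring after the fast copy is lost.

```python
import shutil
import tempfile
from pathlib import Path


def sync_to_longterm(fast_dir, longterm_dir):
    """Copy chunks that exist only in fast storage to long-term storage."""
    fast, longterm = Path(fast_dir), Path(longterm_dir)
    longterm.mkdir(parents=True, exist_ok=True)
    synced = 0
    for chunk in sorted(fast.glob("*.chunk")):
        target = longterm / chunk.name
        if not target.exists():
            # Copy under a temporary name, then rename, so an interruption
            # never leaves a half-written file under the final name.
            partial = target.with_name(target.name + ".partial")
            shutil.copy2(chunk, partial)
            partial.replace(target)
            synced += 1
    return synced


def restore_from_longterm(longterm_dir, fast_dir):
    """Copy saved chunks back into fast storage after a restart."""
    longterm, fast = Path(longterm_dir), Path(fast_dir)
    fast.mkdir(parents=True, exist_ok=True)
    restored = 0
    if longterm.exists():
        for chunk in sorted(longterm.glob("*.chunk")):
            target = fast / chunk.name
            if not target.exists():
                shutil.copy2(chunk, target)
                restored += 1
    return restored


def demo():
    """Simulate one job that is interrupted and then restarted."""
    with tempfile.TemporaryDirectory() as root:
        fast, longterm = Path(root, "fast"), Path(root, "longterm")
        fast.mkdir()
        (fast / "000.chunk").write_bytes(b"steps 0-99")
        (fast / "001.chunk").write_bytes(b"steps 100-199")
        synced = sync_to_longterm(fast, longterm)
        shutil.rmtree(fast)  # the temporary storage disappears with the job
        restored = restore_from_longterm(longterm, fast)
        names = sorted(p.name for p in fast.iterdir())
        return synced, restored, names


if __name__ == "__main__":
    print(demo())
```

In this sketch, the same restore-then-sync sequence runs on every launch, which mirrors why a single command can serve both the first run and each restart.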

POPGym

Use the following command:

python recall2imagine/train.py \
    --configs popgym --task gym_popgym:popgym-RepeatPreviousEasy-v0 \
    --wdb_name popgym_repeat_previous_easy \
    --logdir ./logs_popgym_repeat_previous_easy
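The commands in this README follow a consistent naming pattern: the `--wdb_name` value is a snake_case form of the environment id, and `--logdir` is `./logs_` plus that name. As a convenience sketch (the helpers `run_name` and `train_command` are hypothetical and not part of this repo), the pattern can be generated from the task string, using only the flags shown above:

```python
import re
from typing import List


def run_name(task: str) -> str:
    """Derive a snake_case run name from a task string like 'suite:EnvId-v0'."""
    env_id = task.split(":", 1)[-1]          # drop the 'gym_...:' prefix
    env_id = re.sub(r"-v\d+$", "", env_id)   # drop the version suffix
    # Insert '_' at each lowercase/digit -> uppercase boundary (CamelCase).
    env_id = re.sub(r"(?<=[a-z0-9])(?=[A-Z])", "_", env_id)
    return env_id.lower().replace("-", "_")


def train_command(config: str, task: str) -> List[str]:
    """Build the argument list for the training script shown in this README."""
    name = run_name(task)
    return [
        "python", "recall2imagine/train.py",
        "--configs", config,
        "--task", task,
        "--wdb_name", name,
        "--logdir", f"./logs_{name}",
    ]


if __name__ == "__main__":
    print(" ".join(train_command("popgym", "gym_popgym:popgym-RepeatPreviousEasy-v0")))
```

The resulting list can be passed to `subprocess.run(..., check=True)` to launch the run from a script.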

BSuite

(Coming soon!)

BibTeX

If you find our work useful, please cite our paper:

@inproceedings{
    samsami2024mastering,
    title={Mastering Memory Tasks with World Models},
    author={Mohammad Reza Samsami and Artem Zholus and Janarthanan Rajendran and Sarath Chandar},
    booktitle={The Twelfth International Conference on Learning Representations},
    year={2024},
    url={https://openreview.net/forum?id=1vDArHJ68h}
}

Acknowledgements

We thank Danijar Hafner for providing the DreamerV3 implementation, which this repo is based upon.