
Diffusion for World Modeling: Visual Details Matter in Atari (NeurIPS 2024 Spotlight)

[TL;DR] 💎 DIAMOND (DIffusion As a Model Of eNvironment Dreams) is a reinforcement learning agent trained entirely in a diffusion world model.

🌍 Project Page β€’ πŸ€“ Paper β€’ 𝕏 Atari thread β€’ 𝕏 CSGO thread β€’ πŸ’¬ Discord

<div align='center'> RL agent playing in autoregressive imagination of Atari world models <br> <img alt="DIAMOND agent in WM" src="https://github.com/user-attachments/assets/eb6b72eb-73df-4178-8a3d-cdad80ff9152"> </div> <div align='center'> Human player in CSGO world model (full quality video <a href="https://diamond-wm.github.io/static/videos/grid.mp4">here</a>) <br> <img alt="Human player in CSGO world model" src="https://github.com/user-attachments/assets/dcbdd523-ca22-46a9-bb7d-bcc52080fe00"> </div>

Quick install to try our pretrained world models using miniconda:

git clone https://github.com/eloialonso/diamond.git
cd diamond
conda create -n diamond python=3.10
conda activate diamond
pip install -r requirements.txt

For Atari (world model + RL agent)

python src/play.py --pretrained

For CSGO (world model only)

git checkout csgo
python src/play.py

And press m to take control (the policy is playing by default)!

Warning: Atari ROMs will be downloaded with the dependencies; by installing them, you acknowledge that you have a license to use these ROMs.

CSGO

Edit: Check out the csgo branch to try DIAMOND's world model trained on Counter-Strike: Global Offensive!

git checkout csgo
python src/play.py

Note: on Apple Silicon, you must enable CPU fallback for the MPS backend: PYTORCH_ENABLE_MPS_FALLBACK=1 python src/play.py

<a name="quick_links"></a>

Quick Links

- [Try our playable diffusion world models](#try)
- [Launch a training run](#launch)
- [Configuration](#configuration)
- [Visualization](#visualization)
- [Run folder structure](#structure)
- [Results](#results)
- [Citation](#citation)
- [Credits](#credits)

<a name="try"></a>

⬆️ Try our playable diffusion world models

python src/play.py --pretrained

Then select a game, and the world model and policy pretrained on Atari 100k will be downloaded from our repository on the Hugging Face Hub 🤗 and cached on your machine.

Some things you might want to try:

To adjust the sampling parameters (number of denoising steps, stochasticity, order, etc.) of the trained diffusion world model, for instance to trade off sampling speed and quality, edit the world_model_env.diffusion_sampler section in config/trainer.yaml.
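For illustration, a minimal sketch of what that section might look like (the key names below are assumptions; treat the actual config/trainer.yaml in your checkout as the source of truth):

world_model_env:
  diffusion_sampler:
    num_steps_denoising: 3   # fewer steps = faster imagination, potentially blurrier frames
    order: 1                 # solver order of the sampler
    s_churn: 0.0             # stochasticity injected during sampling (0 = deterministic)

Lowering the number of denoising steps is the most direct way to trade visual quality for sampling speed.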

See Visualization for more details about the available commands and options.

<a name="launch"></a>

⬆️ Launch a training run

To train with the hyperparameters used in the paper on cuda:0, launch:

python src/main.py env.train.id=BreakoutNoFrameskip-v4 common.devices=0

This creates a new folder for your run, located in outputs/YYYY-MM-DD/hh-mm-ss/.

To resume a run that crashed, navigate to the run folder and launch:

./scripts/resume.sh

<a name="configuration"></a>

⬆️ Configuration

We use Hydra for configuration management.

All configuration files are located in the config folder.

You can turn on logging to Weights & Biases in the wandb section of config/trainer.yaml.
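For example, assuming that section exposes a standard wandb mode key (an assumption, not a documented key; check the file), logging could also be switched on for a single run from the command line via a Hydra override:

python src/main.py env.train.id=BreakoutNoFrameskip-v4 common.devices=0 wandb.mode=online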

Set training.model_free=true in the file config/trainer.yaml to "unplug" the world model and perform standard model-free reinforcement learning.
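Since Hydra also accepts command-line overrides (the same syntax as the training command above), the model-free baseline can be launched directly, for instance:

python src/main.py env.train.id=BreakoutNoFrameskip-v4 common.devices=0 training.model_free=true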

<a name="visualization"></a>

⬆️ Visualization

<a name="play_mode"></a>

⬆️ Play mode (default)

To visualize your last checkpoint, launch from the run folder:

python src/play.py

By default, you visualize the policy playing in the world model. To play yourself, or switch to the real environment, use the controls described below.

Controls (play mode)

(Game-specific commands will be printed at startup)

⏎   : reset environment

m   : switch controller (policy/human)
↑/↓ : imagination horizon (+1/-1)
←/β†’ : next environment [world model ←→ real env (test) ←→ real env (train)]

.   : pause/unpause
e   : step-by-step (when paused)

Add -r to toggle "recording mode" (works only in play mode). Every completed episode will be saved in dataset/rec_<env_name>_<controller>.
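For instance, recording while playing in the world model is just the usual play command plus the flag:

python src/play.py -r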

You can then use the "dataset mode" described in the next section to replay the stored episodes.

<a name="dataset_mode"></a>

⬆️ Dataset mode (add -d)

In the run folder, to visualize the datasets contained in the dataset subfolder, add -d to switch to "dataset mode":

python src/play.py -d

You can use the controls described below to navigate the datasets and episodes.

Controls (dataset mode)

m   : next dataset (if multiple datasets, like recordings, etc.)
↑/↓ : next/previous episode
←/→ : next/previous timestep in episodes
PgUp: +10 timesteps
PgDn: -10 timesteps
⏎   : back to first timestep

<a name="other_options"></a>

⬆️ Other options, common to play/dataset modes

--fps FPS             Target frame rate (default 15).
--size SIZE           Window size (default 800).
--no-header           Remove header.
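For instance, these options can be combined with either mode; the following opens dataset mode at a higher frame rate in a smaller window:

python src/play.py -d --fps 30 --size 600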

<a name="structure"></a>

⬆️ Run folder structure

Each new run is located at outputs/YYYY-MM-DD/hh-mm-ss/. This folder is structured as follows:

outputs/YYYY-MM-DD/hh-mm-ss/
│
└─── checkpoints
│   │   state.pt  # full training state
│   │
│   └─── agent_versions
│       │   ...
│       │   agent_epoch_00999.pt
│       │   agent_epoch_01000.pt  # agent weights only
│
└─── config
│   │   trainer.yaml
│
└─── dataset
│   │
│   └─── train
│   │   │   info.pt
│   │   │   ...
│   │
│   └─── test
│       │   info.pt
│       │   ...
│
└─── scripts
│   │   resume.sh
│   │   ...
│
└─── src
│   │   main.py
│   │   ...
│
└─── wandb
    │   ...

<a name="results"></a>

⬆️ Results

The file results/data/DIAMOND.json contains the results for each game and seed used in the paper.

The DDPM code used for Section 5.1 of the paper can be found on the ddpm branch.

<a name="citation"></a>

⬆️ Citation

@inproceedings{alonso2024diffusionworldmodelingvisual,
      title={Diffusion for World Modeling: Visual Details Matter in Atari},
      author={Eloi Alonso and Adam Jelley and Vincent Micheli and Anssi Kanervisto and Amos Storkey and Tim Pearce and François Fleuret},
      booktitle={Thirty-eighth Conference on Neural Information Processing Systems},
      year={2024},
      url={https://arxiv.org/abs/2405.12399},
}

<a name="credits"></a>

⬆️ Credits