Dreamer-v2 PyTorch

PyTorch implementation of Mastering Atari with Discrete World Models.<br>

<p align="middle" > <img src="images/breakout.gif" title="breakout" width="200" /> <img src="images/space_invaders.gif" title="space_invaders" width="200" /> <img src="images/asterix.gif" title="asterix" width="200" /> <img src="images/seaquest.gif" title="seaquest" width="200" /> </p>

Installation

Dependencies:

I have added a requirements.txt (generated with conda list -e > requirements.txt) and an environment.yml (generated with conda env export > environment.yml) from my own conda environment. <br> It is probably easier to create a new conda environment (or venv, etc.) and manually install the few dependencies listed above one by one.
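If you would rather reproduce my environment directly, the exported files should work with the usual conda commands (the environment name dreamerv2 below is just a placeholder):

python -m pip --version  # only to confirm your interpreter; not required
conda env create -f environment.yml

or, from the exported package list:

conda create --name dreamerv2 --file requirements.txt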

Running experiments

  1. In the test folder, mdp.py and pomdp.py have been set up for experiments with MinAtar environments. All default hyper-parameters are stored in a dataclass in config.py. To run dreamerv2 with default HPs on POMDP breakout with CUDA:
python pomdp.py --env breakout --device cuda
  2. Experiments on other environments (using the gym API) can be run by adding another hyper-parameter dataclass in config.py (see the sketch below). <br>
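For example, a new config could look roughly like the sketch below. The dataclass and field names are illustrative only and should mirror whatever fields the existing MinAtar configs in config.py actually define.

```python
from dataclasses import dataclass

@dataclass
class MyGymEnvConfig:
    # Illustrative fields only; copy the field set of the existing MinAtar configs.
    env: str = 'MyGymEnv-v0'        # gym environment id (placeholder)
    obs_shape: tuple = (3, 64, 64)  # observation shape fed to the encoder
    action_size: int = 6            # size of the discrete action space
    train_steps: int = int(1e6)     # total environment steps
    batch_size: int = 50            # sequences per world-model batch
    seq_len: int = 50               # sequence length for world-model training
    device: str = 'cuda'            # 'cuda' or 'cpu'
```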

Evaluating saved models

Trained models for all 5 games (MDP and POMDP versions of each) are uploaded to the drive link: link (64 MB).<br> Download and unzip the models inside the /test directory.

To evaluate the saved model for the POMDP version of breakout for 5 episodes, without rendering:

python eval.py --env breakout --eval_episode 5 --eval_render 0 --pomdp 1
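Here --eval_episode sets the number of evaluation episodes and --eval_render toggles rendering. Conceptually the evaluation boils down to a loop like the one below; this is a generic sketch for orientation, not the actual eval.py, and evaluate, env and agent are placeholder names.

```python
import numpy as np

def evaluate(env, agent, episodes=5, render=False):
    """Roll out a trained agent and report the mean undiscounted return."""
    returns = []
    for _ in range(episodes):
        obs, done, total = env.reset(), False, 0.0
        while not done:
            if render:
                env.render()
            action = agent.act(obs)                  # action from the saved policy
            obs, reward, done, _ = env.step(action)  # old gym API: 4-tuple step
            total += reward
        returns.append(total)
    return float(np.mean(returns))
```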

Evaluation Results

Average evaluation score (over 50 evaluation episodes) of models saved at every 0.1 million frames. Green curves correspond to agents that have access to complete state information, while red curves correspond to agents trained under partial observability.

<img src="images/eval.png" width="5000" height="400">

In freeway, the agent gets stuck in a local maximum where it learns to always move forward, since it is not penalised for crashing into cars. Probably due to policy entropy regularisation, its returns drop drastically around the 1 million frame mark and then gradually improve while maintaining policy entropy.

Training curves

All experiments were logged using wandb. Training runs for all MDP and POMDP variants of MinAtar environments can be found on the wandb project page.
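The logging itself follows the standard wandb pattern, roughly as sketched here; the project name, config keys and metric keys below are placeholders rather than the ones used in the repo.

```python
import wandb

# Placeholder project, config and metric names; the real settings live in config.py.
run = wandb.init(project='dreamerv2-minatar', config={'env': 'breakout', 'pomdp': True})
for frames in range(0, 300_000, 100_000):
    wandb.log({'eval/average_return': 0.0, 'frames': frames})  # dummy values, for illustration
run.finish()
```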

Please create an issue if you find a bug or have any queries.

Code structure:

Hyper-Parameter description:

Acknowledgments

Awesome Environments used for testing:

This code is heavily inspired by the following works: