# Mastering Diverse Domains through World Models
A reimplementation of DreamerV3, a scalable and general reinforcement learning algorithm that masters a wide range of applications with fixed hyperparameters.
If you find this code useful, please cite it in your paper:
```
@article{hafner2023dreamerv3,
  title={Mastering Diverse Domains through World Models},
  author={Hafner, Danijar and Pasukonis, Jurgis and Ba, Jimmy and Lillicrap, Timothy},
  journal={arXiv preprint arXiv:2301.04104},
  year={2023}
}
```
To learn more, see the paper: https://arxiv.org/abs/2301.04104
## DreamerV3
DreamerV3 learns a world model from experiences and uses it to train an actor critic policy from imagined trajectories. The world model encodes sensory inputs into categorical representations and predicts future representations and rewards given actions.
DreamerV3 masters a wide range of domains with a fixed set of hyperparameters, outperforming specialized methods. Removing the need for tuning reduces the amount of expert knowledge and computational resources needed to apply reinforcement learning.
Due to its robustness, DreamerV3 shows favorable scaling properties. Notably, using larger models consistently increases not only its final performance but also its data efficiency. Increasing the number of gradient steps further increases data efficiency.
## Instructions
The code has been tested on Linux and Mac and requires Python 3.11+.
### Docker
You can either use the provided `Dockerfile`, which contains the instructions, or follow the manual instructions below.
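As a rough sketch of the Docker route (the exact commands are documented inside the `Dockerfile` itself), a typical build-and-run cycle with the NVIDIA container runtime might look like the following; the image name `dreamerv3` and the mounted `~/logdir` path are placeholders:

```sh
# Build the image from the repository root (image name is a placeholder).
docker build -t dreamerv3 .

# Run training inside the container with GPU access; mount a host directory
# so logs and checkpoints survive the container.
docker run -it --rm --gpus all -v ~/logdir:/logdir dreamerv3 \
  python dreamerv3/main.py --logdir /logdir/{timestamp} --configs crafter
```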
### Manual
Install JAX and then the other dependencies:
```sh
pip install -U -r embodied/requirements.txt
pip install -U -r dreamerv3/requirements.txt \
  -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
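Before starting a long run, it can help to verify that JAX picked up the intended backend. This one-liner is just a sanity check and not part of the repository:

```sh
# Prints the accelerator devices JAX can use; expect GPU (or TPU) entries
# when the accelerated build is installed, otherwise CPU.
python -c "import jax; print(jax.devices())"
```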
Simple training script:
```sh
python example.py
```
Flexible training script:
```sh
python dreamerv3/main.py \
  --logdir ~/logdir/{timestamp} \
  --configs crafter \
  --run.train_ratio 32
```
To reproduce results, train on the desired task using the corresponding config, such as `--configs atari --task atari_pong`.
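For example, combining the flexible training script above with the Atari config and task, a full invocation for Atari Pong would look like this (the logdir path is up to you):

```sh
python dreamerv3/main.py \
  --logdir ~/logdir/{timestamp} \
  --configs atari \
  --task atari_pong
```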
## Tips
- All config options are listed in `configs.yaml` and you can override them as flags from the command line (see the example invocations after this list).
- The `debug` config block reduces the network size, batch size, duration between logs, and so on for fast debugging (but does not learn a good model).
- By default, the code tries to run on GPU. You can switch to CPU or TPU using the `--jax.platform cpu` flag.
- You can use multiple config blocks that will override defaults in the order they are specified, for example `--configs crafter size50m`.
- By default, metrics are printed to the terminal, appended to a JSON lines file, and written as TensorBoard summaries. Other outputs like WandB can be enabled in the training script.
- If you get a `Too many leaves for PyTreeDef` error, it means you're reloading a checkpoint that is not compatible with the current config. This often happens when reusing an old logdir by accident.
- If you are getting CUDA errors, scroll up because the cause is often just an error that happened earlier, such as out of memory or incompatible JAX and CUDA versions. Try `--batch_size 1` to rule out an out of memory error.
- Many environments are included, some of which require installing additional packages. See the `Dockerfile` for reference.
- When running on custom environments, make sure to specify the observation keys the agent should be using via the `enc.spaces` and `dec.spaces` regex patterns.
- To log metrics from environments without showing them to the agent or storing them in the replay buffer, return them as observation keys with a `log_` prefix and enable logging via the `run.log_keys_...` options.
- To continue stopped training runs, simply run the same command line again and make sure that the `--logdir` points to the same directory.
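To make a few of these tips concrete, here are some example invocations. They are a sketch that only combines flags mentioned above; the `image` and `vector` observation-key patterns in the custom-environment example are placeholders for whatever keys your environment actually returns.

```sh
# Fast debugging run: the debug block overrides the crafter defaults with
# small networks, small batches, and frequent logs.
python dreamerv3/main.py --logdir ~/logdir/{timestamp} --configs crafter debug

# Force CPU execution and a batch size of 1, e.g. to rule out GPU
# out-of-memory errors.
python dreamerv3/main.py --logdir ~/logdir/{timestamp} --configs crafter \
  --jax.platform cpu --batch_size 1

# Custom environment: select which observation keys the encoder and decoder
# use ('image' and 'vector' are placeholder key names).
python dreamerv3/main.py --logdir ~/logdir/{timestamp} --configs crafter \
  --enc.spaces 'image|vector' --dec.spaces 'image|vector'

# Resume a stopped run by re-running the same command with --logdir pointing
# at the existing directory (placeholder path shown here).
python dreamerv3/main.py --logdir ~/logdir/my_previous_run --configs crafter
```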
## Disclaimer
This repository contains a reimplementation of DreamerV3 based on the open source DreamerV2 code base. It is unrelated to Google or DeepMind. The implementation has been tested to reproduce the official results on a range of environments.