muzero-pytorch
PyTorch implementation of MuZero: "Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model", based on the pseudo-code provided by the authors.
Note: This implementation has only been tested on CartPole-v1 and would require modifications (in the config folder) for other environments.
Installation
- Python 3.8, 3.9
- Install requirements:
cd muzero-pytorch
pip install -r requirements.txt
Usage:
- Train:
python main.py --env CartPole-v1 --case classic_control --opr train --force
- Test:
python main.py --env CartPole-v1 --case classic_control --opr test
- Visualize results:
tensorboard --logdir=<result_dir_path>
- If --use_wandb was passed, you can also visualize results in wandb.
Required Arguments | Description |
---|---|
--env | Name of the environment |
--case {atari,classic_control,box2d} | Switches between different environment domains (default: None) |
--opr {train,test} | Selects the operation to perform |
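A minimal sketch of how the three required arguments could be declared with Python's standard `argparse` module. This is a hypothetical reconstruction for illustration, not the repository's actual `main.py`; `build_parser` is an illustrative name. Restricting `--case` and `--opr` with `choices` makes argparse reject unsupported values before any training starts.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical reconstruction of the required CLI arguments;
    # the real main.py may declare them differently.
    parser = argparse.ArgumentParser(description="MuZero (sketch)")
    parser.add_argument("--env", required=True, help="Name of the environment")
    parser.add_argument(
        "--case",
        required=True,
        choices=["atari", "classic_control", "box2d"],
        help="Switches between different environment domains",
    )
    parser.add_argument(
        "--opr",
        required=True,
        choices=["train", "test"],
        help="Selects the operation to perform",
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(
        ["--env", "CartPole-v1", "--case", "classic_control", "--opr", "train"]
    )
    print(args.env, args.case, args.opr)  # CartPole-v1 classic_control train
```

Passing a value outside `choices` (for example `--case mujoco`) makes argparse print a usage error and exit with a non-zero status.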
Optional Arguments | Description |
---|---|
--value_loss_coeff | Scaling coefficient for the value loss (default: None) |
--revisit_policy_search_rate | Rate at which the target policy is re-estimated (default: None) (only valid if --use_target_model is enabled) |
--use_priority | Uses prioritized sampling in the replay buffer; the priority of new data is computed from its loss (default: False) |
--use_max_priority | Forces max-priority assignment for new incoming data in the replay buffer (only valid if --use_priority is enabled) (default: False) |
--use_target_model | Use target model for bootstrap value estimation (default: False) |
--result_dir | Directory path to store results (default: current working directory) |
--no_cuda | Disables CUDA usage (default: False) |
--no_mps | Disables MPS (Metal Performance Shaders) usage (default: False) |
--debug | If enabled, logs additional values (default: False) |
--render | Renders the environment (default: False) |
--force | Overrides past results (default: False) |
--seed | Random seed (default: 0) |
--num_actors | Number of actors running concurrently (default: 32) |
--test_episodes | Evaluation episode count (default: 10) |
--use_wandb | Logs console and TensorBoard data to wandb (default: False) |
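A minimal sketch of the replay-buffer behavior that --use_priority and --use_max_priority describe. It assumes priority equals the absolute loss (plus a small epsilon) and that sampling probability is proportional to priority; the repository may use a different scheme (for example priority exponents or importance-sampling weights). `PrioritizedReplay` and its methods are hypothetical names.

```python
import random


class PrioritizedReplay:
    """Hypothetical buffer illustrating --use_priority / --use_max_priority."""

    def __init__(self, use_priority=True, use_max_priority=False, seed=0):
        self.items, self.priorities = [], []
        self.use_priority = use_priority
        self.use_max_priority = use_max_priority
        self.rng = random.Random(seed)

    def add(self, item, loss):
        # New data gets the current max priority if requested,
        # otherwise a priority computed from its loss.
        if self.use_max_priority and self.priorities:
            priority = max(self.priorities)
        else:
            priority = abs(loss) + 1e-6  # epsilon keeps every item sampleable
        self.items.append(item)
        self.priorities.append(priority)

    def probabilities(self):
        # Without --use_priority, sampling is uniform.
        if not self.use_priority:
            n = len(self.items)
            return [1.0 / n] * n
        total = sum(self.priorities)
        return [p / total for p in self.priorities]

    def sample(self, batch_size):
        # Draw indices with replacement, weighted by the probabilities above.
        indices = self.rng.choices(
            range(len(self.items)), weights=self.probabilities(), k=batch_size
        )
        return indices, [self.items[i] for i in indices]

    def update_priorities(self, indices, losses):
        # Refresh priorities after the sampled batch has been trained on.
        for i, loss in zip(indices, losses):
            self.priorities[i] = abs(loss) + 1e-6


if __name__ == "__main__":
    buf = PrioritizedReplay()
    buf.add("a", loss=1.0)
    buf.add("b", loss=3.0)
    print(buf.probabilities())  # roughly [0.25, 0.75]
```

Assigning the max priority to new data guarantees it is sampled at least as often as the most "surprising" existing item, at the cost of ignoring its actual loss until the first priority update.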
Note: default: None means the value is loaded from the corresponding config.
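The "None means use the config" convention can be sketched as follows. `ExampleConfig`, `build_optional_parser` and `resolve` are hypothetical names, and the config values are placeholders, not the repository's actual CartPole settings. The idea is that `default=None` marks an argument as "not given", so only explicitly passed CLI values override the config.

```python
import argparse
from dataclasses import dataclass


@dataclass
class ExampleConfig:
    # Placeholder values for illustration; not the repository's CartPole settings.
    value_loss_coeff: float = 0.25
    revisit_policy_search_rate: float = 0.5


def build_optional_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # default=None marks "not given on the command line"
    parser.add_argument("--value_loss_coeff", type=float, default=None)
    parser.add_argument("--revisit_policy_search_rate", type=float, default=None)
    return parser


def resolve(args: argparse.Namespace, config: ExampleConfig) -> ExampleConfig:
    """Override config fields only with CLI values that were actually provided."""
    for name in list(vars(config)):
        value = getattr(args, name, None)
        if value is not None:
            setattr(config, name, value)
    return config


if __name__ == "__main__":
    parser = build_optional_parser()
    cfg = resolve(parser.parse_args(["--value_loss_coeff", "1.0"]), ExampleConfig())
    print(cfg)  # ExampleConfig(value_loss_coeff=1.0, revisit_policy_search_rate=0.5)
```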
Training
CartPole-v1
- Curves represent model evaluation over 5 episodes, performed every 100 training steps.
- Each curve is the mean score over 5 runs (seeds: [0, 100, 200, 300, 400]).
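A small sketch of the averaging described above, assuming each run is stored as a list of evaluation scores (one per 100-step checkpoint) and each score is the mean return of the evaluation episodes. `checkpoint_score` and `mean_curve` are hypothetical helpers, not part of the repository, and the numbers in the example are made up.

```python
from statistics import mean


def checkpoint_score(episode_returns):
    """Evaluation score at one checkpoint: mean return over the test episodes."""
    return mean(episode_returns)


def mean_curve(runs):
    """Average per-checkpoint scores across seeds.

    runs maps seed -> list of checkpoint scores (one per 100 training steps).
    All runs must have the same number of checkpoints.
    """
    lengths = {len(scores) for scores in runs.values()}
    if len(lengths) != 1:
        raise ValueError("all runs must have the same number of checkpoints")
    # zip(*...) groups the i-th checkpoint score of every seed together
    return [mean(scores_at_step) for scores_at_step in zip(*runs.values())]


if __name__ == "__main__":
    # Made-up scores for illustration only (not results from this repository).
    runs = {
        0: [10.0, 100.0],
        100: [20.0, 200.0],
        200: [30.0, 300.0],
        300: [40.0, 400.0],
        400: [50.0, 500.0],
    }
    print(mean_curve(runs))  # [30.0, 300.0]
```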