MUSE-Pytorch🎨
This is a PyTorch implementation of MUSE, with checkpoints pre-trained on ImageNet and CC3M.
Unlike the original MUSE, which conditions on text through cross-attention, we employ an in-context conditioning version of MUSE and adopt the recently proposed U-ViT backbone for its strong performance in image generation. A text-to-image version of our pipeline is illustrated below:
Note:
- Due to computational constraints, the released models are notably undertrained. Nonetheless, they already achieve satisfactory performance, and we release them to facilitate community research. Training can also be resumed from the released checkpoints for better results.
- The core functionality of MUSE is implemented in the MUSE class (only ~60 lines) in libs/muse.py.
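The MUSE class is built around masked token modeling. A random fraction of an image's discrete VQ tokens is replaced by a [MASK] token, and the network learns to predict the original tokens at those positions.

This is combined with the in-context conditioning mentioned above: the condition tokens share one sequence with the image tokens instead of being injected through cross-attention. A training input could therefore be assembled as in the minimal, dependency-free sketch below.

It assumes the cosine masking schedule from MaskGIT and the MUSE paper. MASK_ID, the function names, the toy condition token and the codebook size of 1024 are all hypothetical, and none of them is taken from libs/muse.py.

```python
import math
import random
from typing import List, Sequence, Tuple

MASK_ID = -1  # hypothetical [MASK] id; real implementations reserve an extra codebook index


def cosine_mask_ratio(r: float) -> float:
    """Map r in [0, 1) to a masking ratio in (0, 1] with the cosine schedule."""
    return math.cos(math.pi / 2 * r)


def mask_tokens(tokens: Sequence[int], rng: random.Random) -> Tuple[List[int], List[bool]]:
    """Replace a cosine-scheduled random subset of tokens with MASK_ID.

    Returns the corrupted tokens and a boolean mask marking the positions on
    which the cross-entropy loss against the original tokens is computed.
    """
    n = len(tokens)
    num_masked = max(1, math.ceil(cosine_mask_ratio(rng.random()) * n))
    masked = set(rng.sample(range(n), num_masked))
    inputs = [MASK_ID if i in masked else t for i, t in enumerate(tokens)]
    return inputs, [i in masked for i in range(n)]


def in_context_input(cond: Sequence[object], image_tokens: Sequence[int]) -> Tuple[List[object], int]:
    """In-context conditioning: prepend condition tokens to the image tokens.

    The transformer self-attends over the joint sequence (no cross-attention);
    the returned offset marks where the image-token predictions start.
    """
    return list(cond) + list(image_tokens), len(cond)


if __name__ == "__main__":
    rng = random.Random(0)
    image = [rng.randrange(1024) for _ in range(256)]  # toy ids; 1024 is an assumed codebook size
    corrupted, is_masked = mask_tokens(image, rng)
    seq, offset = in_context_input(["<class 207>"], corrupted)  # hypothetical class-condition token
    print(f"sequence length {len(seq)}, image tokens start at {offset}, {sum(is_masked)} masked")
```

With text conditioning, the prepended items would be text-embedding vectors rather than a single class token. The structure stays the same.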
Pretrained Models
The pre-trained models are released on 🤗HuggingFace; details are given below:
| Dataset | Model | #Params | #Training iterations | Batch size | FID |
|---|---|---|---|---|---|
| ImageNet 256x256 | U-ViT-B (depth=13, width=768) | 102M | 450K | 2048 | 3.84 (12 steps) |
| CC3M | U-ViT-Huge (depth=29, width=1152) | 501M | 285K | 2048 | 6.84 (18 steps) |
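The step counts next to the FID values are the number of iterative decoding steps used at sampling time. MUSE starts from a fully masked token grid and, at each step, keeps its most confident predictions and re-masks the rest.

The sketch below computes only how many tokens remain masked after each step. It assumes a MaskGIT-style cosine schedule and 256 tokens, i.e. a 16x16 grid, which is what an f16 VQGAN produces for 256x256 images. The schedule actually used in libs/muse.py may differ, and the function name is hypothetical.

```python
import math
from typing import List


def masked_counts(num_tokens: int, steps: int) -> List[int]:
    """Tokens still masked after each decoding step under a cosine schedule."""
    counts = []
    remaining = num_tokens
    for t in range(1, steps + 1):
        target = math.floor(num_tokens * math.cos(math.pi / 2 * t / steps))
        # Commit at least one new token per step so decoding always makes progress.
        remaining = max(0, min(remaining - 1, target))
        counts.append(remaining)
    return counts


if __name__ == "__main__":
    for steps in (12, 18):
        counts = masked_counts(256, steps)
        committed = [prev - cur for prev, cur in zip([256] + counts, counts)]
        print(f"{steps} steps -> tokens committed per step: {committed}")
```

Under this schedule, early steps commit only a few tokens and later steps commit many. The model therefore fixes a small, confident subset before filling in the rest.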
Dependencies
conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
pip install accelerate==0.12.0 absl-py ml_collections einops wandb ftfy==6.1.1 transformers==4.23.1 loguru webdataset==0.2.5
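Several packages above are pinned to exact versions, so a quick check can save a confusing failure later. This is a small sketch using the standard-library importlib.metadata (Python 3.8+).

The PINS dictionary simply mirrors the pip command. The check_pins helper is hypothetical and not part of the repository.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Dict, List

# Exact pins copied from the pip command above.
PINS: Dict[str, str] = {
    "accelerate": "0.12.0",
    "ftfy": "6.1.1",
    "transformers": "4.23.1",
    "webdataset": "0.2.5",
}


def check_pins(pins: Dict[str, str], lookup: Callable[[str], str] = version) -> List[str]:
    """Return one human-readable problem per package that is missing or differs from its pin."""
    problems = []
    for name, wanted in pins.items():
        try:
            found = lookup(name)
        except PackageNotFoundError:
            problems.append(f"{name}: not installed (want {wanted})")
            continue
        if found != wanted:
            problems.append(f"{name}: found {found}, want {wanted}")
    return problems


if __name__ == "__main__":
    for line in check_pins(PINS) or ["all pinned packages match"]:
        print(line)
```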
Data Preparation
First, download the VQGAN checkpoint from this link (from MAGE, thanks!) and save it as assets/vqgan_jax_strongaug.ckpt.
- ImageNet 256x256: extract ImageNet features by running:
  python extract_imagenet_feature.py your/imagenet/path
- CC3M:
  - First, prepare some context features for training by running
    python extract_test_prompt_feature.py
    and
    python extract_empty_feature.py
  - Next, prepare the webdataset with feature2webdataset.py.
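On the image side, MUSE operates on discrete tokens produced by the VQGAN downloaded above. The encoder maps an image to a grid of latent vectors, and each vector is replaced by the index of its nearest codebook entry.

The sketch below shows only that nearest-neighbour lookup, in plain Python with toy values. This README does not state whether the extraction scripts store token ids or other intermediate features, and the quantize function is hypothetical.

```python
from typing import List, Sequence


def quantize(latents: Sequence[Sequence[float]], codebook: Sequence[Sequence[float]]) -> List[int]:
    """Return, for each latent vector, the index of the nearest codebook entry (squared L2)."""
    ids = []
    for z in latents:
        dists = [sum((a - b) ** 2 for a, b in zip(z, entry)) for entry in codebook]
        ids.append(min(range(len(codebook)), key=dists.__getitem__))
    return ids


if __name__ == "__main__":
    codebook = [[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]]  # toy 3-entry codebook
    latents = [[0.1, 0.2], [4.0, 4.5], [0.9, 1.2]]   # toy encoder outputs
    print(quantize(latents, codebook))  # [0, 2, 1]
```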
Training & Evaluation
Download the reference statistics for FID from this link and place the downloaded .npz file in assets/fid_stats.
Next, download the pre-trained checkpoints from this link into assets/ckpts, either for evaluation or to continue training for more iterations.
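Before launching a run, a small preflight check can confirm that the files from the steps above are in place. This is a hypothetical helper, not part of the repository, and it checks only the paths named in this README.

A checkpoint is treated as required only when an evaluation checkpoint path is passed, for example the value you would export as EVAL_CKPT in the commands that follow. That path is checked with exists() rather than is_file(), since it is not assumed here whether a .ckpt is a file or a directory.

```python
import os
from pathlib import Path
from typing import List, Optional


def missing_assets(repo_root: str = ".", eval_ckpt: Optional[str] = None) -> List[str]:
    """List required files that are missing, relative to the repository root."""
    root = Path(repo_root)
    missing = []
    if not (root / "assets" / "vqgan_jax_strongaug.ckpt").is_file():
        missing.append("assets/vqgan_jax_strongaug.ckpt (VQGAN weights)")
    fid_dir = root / "assets" / "fid_stats"
    if not (fid_dir.is_dir() and any(fid_dir.glob("*.npz"))):
        missing.append("assets/fid_stats/*.npz (FID reference statistics)")
    if eval_ckpt and not (root / eval_ckpt).exists():
        missing.append(f"{eval_ckpt} (checkpoint to evaluate)")
    return missing


if __name__ == "__main__":
    problems = missing_assets(".", os.environ.get("EVAL_CKPT"))
    print("\n".join(problems) if problems else "all assets found")
```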
ImageNet 256x256 (class-conditional)
# export EVAL_CKPT="assets/ckpts/imagenet256-450000.ckpt" # uncomment this to perform evaluation. Otherwise, perform training.
export OUTPUT_DIR="output_dir/for/this/experiment"
mkdir -p $OUTPUT_DIR
accelerate launch --num_processes 8 --mixed_precision fp16 train_t2i_discrete_muse.py \
--config=configs/imagenet256_base_vq_jax.py
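If you prefer to drive runs from Python, the sketch below assembles the same accelerate command shown above. It sets OUTPUT_DIR, plus EVAL_CKPT only when evaluating, in the child process's environment, which has the same effect as the export lines.

The helper names are hypothetical. The flags are exactly those used in the commands in this README, and the same helper covers the CC3M run with train_t2i_discrete_wds.py and its config.

```python
import os
import subprocess
from pathlib import Path
from typing import Dict, List, Optional, Tuple


def build_launch(script: str, config: str, output_dir: str,
                 eval_ckpt: Optional[str] = None,
                 num_processes: int = 8) -> Tuple[List[str], Dict[str, str]]:
    """Return (argv, env) for `accelerate launch`; eval_ckpt=None means training."""
    env = dict(os.environ)
    env["OUTPUT_DIR"] = output_dir
    if eval_ckpt is None:
        env.pop("EVAL_CKPT", None)  # make sure a stale export does not switch the run to evaluation
    else:
        env["EVAL_CKPT"] = eval_ckpt
    argv = ["accelerate", "launch", "--num_processes", str(num_processes),
            "--mixed_precision", "fp16", script, f"--config={config}"]
    return argv, env


def run(script: str, config: str, output_dir: str, eval_ckpt: Optional[str] = None) -> None:
    """Equivalent of `mkdir -p $OUTPUT_DIR` followed by the launch command."""
    argv, env = build_launch(script, config, output_dir, eval_ckpt)
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    subprocess.run(argv, env=env, check=True)
```

For example, run("train_t2i_discrete_muse.py", "configs/imagenet256_base_vq_jax.py", "output_dir/for/this/experiment", eval_ckpt="assets/ckpts/imagenet256-450000.ckpt") evaluates the released ImageNet checkpoint. Omitting eval_ckpt trains instead.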
Expected evaluation results:
step=450000 fid50000=3.8392620678172307
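For reference, FID fits one Gaussian to Inception-v3 features of the reference images and another to those of the generated samples. The number after "fid" in these logs (50000 here, 30000 for CC3M) presumably counts the generated samples.

Let $\mu_r, \Sigma_r$ be the mean and covariance captured by the reference .npz statistics, and $\mu_g, \Sigma_g$ those of the generated images. Lower is better:

$$
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \operatorname{Tr}\left(\Sigma_r + \Sigma_g - 2\left(\Sigma_r \Sigma_g\right)^{1/2}\right)
$$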
CC3M (text-to-image)
# export EVAL_CKPT="assets/ckpts/cc3m-285000.ckpt" # uncomment this to perform evaluation. Otherwise, perform training.
export OUTPUT_DIR="output_dir/for/this/experiment"
mkdir -p $OUTPUT_DIR
accelerate launch --num_processes 8 --mixed_precision fp16 train_t2i_discrete_wds.py \
--config=configs/cc3m_xl_vqf16_jax_2048bs_featset_CLIP_G.py
Expected evaluation results:
step=285000 fid30000=6.835978505261096
The generated images are stored in $OUTPUT_DIR/eval_samples. Each time the script is executed, a new timestamped sub-directory is created to store the generated images.
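Because each run creates a new timestamped sub-directory, the most recent results can be located by modification time. The sketch below is a pair of hypothetical helpers using only pathlib.

It does not assume any particular timestamp format. The image extensions it filters on are an assumption, since this README does not state the output format.

```python
import os
from pathlib import Path
from typing import List, Optional, Sequence


def latest_eval_run(output_dir: str) -> Optional[Path]:
    """Return the most recently modified sub-directory of OUTPUT_DIR/eval_samples, if any."""
    root = Path(output_dir) / "eval_samples"
    if not root.is_dir():
        return None
    runs = [p for p in root.iterdir() if p.is_dir()]
    return max(runs, key=lambda p: p.stat().st_mtime, default=None)


def list_images(run_dir: Path, exts: Sequence[str] = (".png", ".jpg", ".jpeg")) -> List[Path]:
    """Sorted image files in a run directory (the extension list is an assumption)."""
    return sorted(p for p in run_dir.iterdir() if p.is_file() and p.suffix.lower() in exts)


if __name__ == "__main__":
    latest = latest_eval_run(os.environ.get("OUTPUT_DIR", "."))
    if latest is None:
        print("no eval_samples runs found")
    else:
        print(f"{latest}: {len(list_images(latest))} images")
```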
This implementation is based on
Thanks to all authors for their work!