MindEye: fMRI-to-Image reconstruction & retrieval

May 22 2024: Check out our new work, MindEye2, which beats MindEye1 across all metrics!

Project page: https://medarc-ai.github.io/mindeye/

arXiv preprint: https://arxiv.org/abs/2305.18274

Installation instructions

  1. Agree to the Natural Scenes Dataset's Terms and Conditions and fill out the NSD Data Access form

  2. Download this repository: git clone https://github.com/MedARC-AI/fMRI-reconstruction-NSD.git

  3. Run setup.sh to create a conda environment that contains the packages necessary to run our scripts; activate the environment with conda activate mindeye.

cd fMRI-reconstruction-NSD/src
. setup.sh
  4. (optional) For LAION-5B retrieval you will additionally need to download pretrained checkpoints. To do this, cd into the "src" folder and run . download.sh. This lets you train the diffusion prior starting from a pretrained checkpoint (a text-to-image diffusion prior trained on LAION-Aesthetics). We observed that starting from this checkpoint, rather than training the prior from scratch, significantly improved LAION-5B retrieval.

General information

This repository contains Jupyter notebooks for

  1. Training MindEye (src/Train_MindEye.ipynb)
  2. Reconstructing images from brain activity using the trained model (src/Reconstructions.ipynb)
  3. Retrieving images from brain activity either from the test set or via LAION-5B (src/Retrievals.ipynb)
  4. Evaluating reconstructions against the ground truth images according to low- and high-level image metrics (src/Reconstruction_Metrics.ipynb)

Each of the above Jupyter notebooks also has a corresponding Python (.py) file that can be run from the command line.

This repo also contains code for mapping brain activity to the variational autoencoder of Stable Diffusion (src/train_autoencoder.py).

Pre-trained Subject 1 models

You can skip training MindEye yourself and run the rest of the notebooks on Subject 1 of NSD using our pre-trained models. Download them from Hugging Face and place the folders containing the model checkpoints inside "fMRI-reconstruction-NSD/train_logs/".

prior_257_final_subj01_bimixco_softclip_byol: CLIP ViT-L/14 hidden layer (257x768) 
prior_1x768_final_subj01_bimixco_softclip_byol: CLIP ViT-L/14 final layer (1x768)
autoencoder_subj01_4x_locont_no_reconst: Stable Diffusion VAE (low-level pipeline)
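
Before running the downstream notebooks, it can help to confirm that the three folders listed above are where the scripts will look for them. The following is a minimal sketch, not part of the repository: the helper name missing_checkpoint_dirs is hypothetical, and the path assumes you run it from the directory that contains the cloned repo.

```python
from pathlib import Path

# Folder names copied from the list of pre-trained Subject 1 models above.
EXPECTED_CHECKPOINT_DIRS = [
    "prior_257_final_subj01_bimixco_softclip_byol",
    "prior_1x768_final_subj01_bimixco_softclip_byol",
    "autoencoder_subj01_4x_locont_no_reconst",
]


def missing_checkpoint_dirs(train_logs):
    """Return the expected checkpoint folders that are absent from train_logs."""
    root = Path(train_logs)
    return [name for name in EXPECTED_CHECKPOINT_DIRS if not (root / name).is_dir()]


# Assumed location: the current directory contains the cloned repository.
missing = missing_checkpoint_dirs("fMRI-reconstruction-NSD/train_logs")
if missing:
    print("Missing checkpoint folders:", ", ".join(missing))
else:
    print("All pre-trained Subject 1 checkpoint folders found.")
```

The check only looks for the folders themselves; it does not verify the checkpoint files inside them.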

Training MindEye (high-level pipeline)

Train MindEye via Train_MindEye.py.

Various arguments can be set (see below) for training; the default is to train MindEye to the last hidden layer of CLIP ViT-L/14 using the same settings as our paper, for Subject 1 of NSD.

Trained model checkpoints are saved inside the "fMRI-reconstruction-NSD/train_logs" folder. All other outputs are saved inside the "fMRI-reconstruction-NSD/src" folder.
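
If you prefer to launch training from another Python script, the command line can be assembled programmatically. The sketch below is illustrative only: build_train_command and the model name "my_1x768_model" are hypothetical, and it uses only flags that appear in the --help output of Train_MindEye.py.

```python
import subprocess
import sys


def build_train_command(model_name, subj=1, hidden=True, num_epochs=None, wandb_log=False):
    """Assemble a Train_MindEye.py command line from a few documented flags."""
    cmd = [sys.executable, "Train_MindEye.py", "--model_name", model_name, "--subj", str(subj)]
    # Boolean options come in --flag / --no-flag pairs in the help output.
    cmd.append("--hidden" if hidden else "--no-hidden")
    cmd.append("--wandb_log" if wandb_log else "--no-wandb_log")
    if num_epochs is not None:
        cmd += ["--num_epochs", str(num_epochs)]
    return cmd


# Example: map brain activity to the final CLIP layer (1x768) for Subject 1.
cmd = build_train_command("my_1x768_model", subj=1, hidden=False)
print(" ".join(cmd))

# To actually start training, run this from inside fMRI-reconstruction-NSD/src:
# subprocess.run(cmd, check=True)
```

Keeping the subprocess call commented out lets you inspect the command first.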

$ python Train_MindEye.py --help
usage: Train_MindEye.py [-h] [--model_name MODEL_NAME] [--data_path DATA_PATH]
                        [--subj {1,2,5,7}] [--batch_size BATCH_SIZE]
                        [--hidden | --no-hidden]
                        [--clip_variant {RN50,ViT-L/14,ViT-B/32,RN50x64}]
                        [--wandb_log | --no-wandb_log]
                        [--resume_from_ckpt | --no-resume_from_ckpt]
                        [--wandb_project WANDB_PROJECT]
                        [--mixup_pct MIXUP_PCT] [--norm_embs | --no-norm_embs]
                        [--use_image_aug | --no-use_image_aug]
                        [--num_epochs NUM_EPOCHS] [--prior | --no-prior]
                        [--v2c | --no-v2c] [--plot_umap | --no-plot_umap]
                        [--lr_scheduler_type {cycle,linear}]
                        [--ckpt_saving | --no-ckpt_saving]
                        [--ckpt_interval CKPT_INTERVAL]
                        [--save_at_end | --no-save_at_end] [--seed SEED]
                        [--max_lr MAX_LR] [--n_samples_save {0,1}]
                        [--use_projector | --no-use_projector]
                        [--vd_cache_dir VD_CACHE_DIR]

Model Training Configuration

options:
  -h, --help            show this help message and exit
  --model_name MODEL_NAME
                        name of model, used for ckpt saving and wandb logging
                        (if enabled)
  --data_path DATA_PATH
                        Path to where NSD data is stored / where to download
                        it to
  --subj {1,2,5,7}
  --batch_size BATCH_SIZE
                        Batch size can be increased by 10x if only training
                        v2c and not diffusion prior
  --hidden, --no-hidden
                        if True, CLIP embeddings will come from last hidden
                        layer (e.g., 257x768 - Versatile Diffusion), rather
                        than final layer (default: True)
  --clip_variant {RN50,ViT-L/14,ViT-B/32,RN50x64}
                        OpenAI clip variant
  --wandb_log, --no-wandb_log
                        whether to log to wandb (default: False)
  --resume_from_ckpt, --no-resume_from_ckpt
                        if not using wandb and want to resume from a ckpt
                        (default: False)
  --wandb_project WANDB_PROJECT
                        wandb project name
  --mixup_pct MIXUP_PCT
                        proportion of way through training when to switch from
                        BiMixCo to SoftCLIP
  --norm_embs, --no-norm_embs
                        Do l2-norming of CLIP embeddings (default: True)
  --use_image_aug, --no-use_image_aug
                        whether to use image augmentation (default: True)
  --num_epochs NUM_EPOCHS
                        number of epochs of training
  --prior, --no-prior   if False, will only use CLIP loss and ignore diffusion
                        prior (default: True)
  --v2c, --no-v2c       if False, will only use diffusion prior loss (default:
                        True)
  --plot_umap, --no-plot_umap
                        Plot UMAP plots alongside reconstructions (default:
                        False)
  --lr_scheduler_type {cycle,linear}
  --ckpt_saving, --no-ckpt_saving
  --ckpt_interval CKPT_INTERVAL
                        save backup ckpt and reconstruct every x epochs
  --save_at_end, --no-save_at_end
                        if True, saves best.ckpt at end of training. if False
                        and ckpt_saving==True, will save best.ckpt whenever
                        epoch shows best validation score (default: False)
  --seed SEED
  --max_lr MAX_LR
  --n_samples_save {0,1}
                        Number of reconstructions for monitoring progress, 0
                        will speed up training
  --use_projector, --no-use_projector
                        Additional MLP after the main MLP so model can
                        separately learn a way to minimize NCE from prior loss
                        (BYOL) (default: True)
  --vd_cache_dir VD_CACHE_DIR
                        Where is cached Versatile Diffusion model; if not
                        cached will download to this path

Reconstructing from pre-trained MindEye

Once you have pre-trained model checkpoints in your "train_logs" folder, either from running Train_MindEye.py or from downloading our pre-trained Subject 1 models from Hugging Face, you can reconstruct images from the held-out test set of brain activity.

Reconstructions.py defaults to outputting Versatile Diffusion reconstructions as a torch .pt file, without img2img and without second-order selection (recons_per_sample=1).

$ python Reconstructions.py --help
usage: Reconstructions.py [-h] [--model_name MODEL_NAME]
                          [--autoencoder_name AUTOENCODER_NAME] [--data_path DATA_PATH]
                          [--subj {1,2,5,7}] [--img2img_strength IMG2IMG_STRENGTH]
                          [--recons_per_sample RECONS_PER_SAMPLE]
                          [--vd_cache_dir VD_CACHE_DIR]

Model Training Configuration

options:
  -h, --help            show this help message and exit
  --model_name MODEL_NAME
                        name of trained model
  --autoencoder_name AUTOENCODER_NAME
                        name of trained autoencoder model
  --data_path DATA_PATH
                        Path to where NSD data is stored (see README)
  --subj {1,2,5,7}
  --img2img_strength IMG2IMG_STRENGTH
                        How much img2img (1=no img2img; 0=outputting the low-level image
                        itself)
  --recons_per_sample RECONS_PER_SAMPLE
                        How many recons to output, to then automatically pick the best
                        one (MindEye uses 16)
  --vd_cache_dir VD_CACHE_DIR
                        Where is cached Versatile Diffusion model; if not cached will
                        download to this path
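
The two endpoints of --img2img_strength (1 = no img2img, 0 = the low-level image itself) match a convention common in img2img diffusion pipelines, where the strength sets how many of the scheduled denoising steps are run on a partially noised starting image. The sketch below illustrates that convention only. The function name and the step count of 20 are hypothetical, and whether Reconstructions.py computes the value in exactly this way is an assumption.

```python
def denoising_steps_to_run(num_inference_steps, img2img_strength):
    """Number of denoising steps applied on top of the low-level image.

    strength = 1.0 runs every step from pure noise (the low-level image is ignored);
    strength = 0.0 runs no steps, so the low-level image is returned unchanged.
    """
    if not 0.0 <= img2img_strength <= 1.0:
        raise ValueError("img2img_strength must be between 0 and 1")
    return min(int(num_inference_steps * img2img_strength), num_inference_steps)


print(denoising_steps_to_run(20, 1.0))  # 20: full generation, no img2img
print(denoising_steps_to_run(20, 0.5))  # 10: half of the schedule
print(denoising_steps_to_run(20, 0.0))  # 0: low-level image itself
```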

Image/Brain Retrieval (inc. LAION-5B image retrieval)

To evaluate image/brain retrieval on the NSD test set, use the Jupyter notebook Retrievals.ipynb and follow the code blocks under the "Image/Brain Retrieval" heading.

Running Retrievals.py retrieves the 16 nearest neighbors in LAION-5B using the MindEye variant in which brain activity is mapped to the final layer of CLIP. This is followed by second-order selection: the 16 retrieved images are converted to CLIP last-hidden-layer embeddings and compared to the outputs of the core MindEye model, in which brain activity is mapped to the last hidden layer of CLIP. The retrieved image with the highest CLIP similarity is chosen, and all top-1 retrievals are saved to a torch .pt file.
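
The two-stage procedure can be sketched on toy data. This is a simplified illustration in pure Python, not the code in Retrievals.py: the function names are hypothetical, hidden-layer embeddings are assumed to be flattened into plain vectors, and a brute-force search over a small in-memory list stands in for the LAION-5B nearest-neighbor lookup.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two non-zero vectors given as lists."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k_neighbors(query, candidates, k):
    """Indices of the k candidates most similar to query, best first."""
    ranked = sorted(range(len(candidates)),
                    key=lambda i: cosine_similarity(query, candidates[i]),
                    reverse=True)
    return ranked[:k]


def second_order_selection(brain_final, brain_hidden, final_embs, hidden_embs, k=16):
    """Shortlist k images by final-layer similarity, then pick the one whose
    hidden-layer embedding best matches the brain-predicted hidden embedding."""
    shortlist = top_k_neighbors(brain_final, final_embs, k)
    return max(shortlist, key=lambda i: cosine_similarity(brain_hidden, hidden_embs[i]))


# Toy database of three images with 2-d "final" and 3-d "hidden" embeddings.
final_embs = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]
hidden_embs = [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]

best = second_order_selection([1.0, 0.0], [1.0, 0.0, 0.0], final_embs, hidden_embs, k=2)
print(best)  # 1
```

In the toy example, image 2 matches the hidden-layer query as well as image 1 does, but it is never considered because it did not make the first-stage shortlist.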

$ python Retrievals.py --help
usage: Retrievals.py [-h] [--model_name MODEL_NAME]
                     [--model_name2 MODEL_NAME2] [--data_path DATA_PATH]
                     [--subj {1,2,5,7}]

Model Training Configuration

options:
  -h, --help            show this help message and exit
  --model_name MODEL_NAME
                        name of 257x768 model, used for everything except LAION-5B
                        retrieval
  --model_name2 MODEL_NAME2
                        name of 1x768 model, used for LAION-5B retrieval
  --data_path DATA_PATH
                        Path to where NSD data is stored (see README)
  --subj {1,2,5,7}

Evaluating Reconstructions

After you have saved a .pt file from running Reconstructions.py or Retrievals.py, you can use Reconstruction_Metrics.py to evaluate reconstructed images using the same low- and high-level image metrics used in the paper.

$ python Reconstruction_Metrics.py --help
usage: Reconstruction_Metrics.py [-h] [--recon_path RECON_PATH]
                                 [--all_images_path ALL_IMAGES_PATH]

Model Training Configuration

options:
  -h, --help            show this help message and exit
  --recon_path RECON_PATH
                        path to reconstructed/retrieved outputs
  --all_images_path ALL_IMAGES_PATH
                        path to ground truth outputs
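
As an illustration of what a low-level image metric measures, the sketch below computes a pixel-wise Pearson correlation between a reconstruction and its ground truth image. It is a simplified stand-in written in pure Python, not the implementation in Reconstruction_Metrics.py. The function name is hypothetical and images are assumed to be flattened into lists of pixel values.

```python
import math


def pixel_correlation(recon, target):
    """Pearson correlation between two images given as flat lists of pixel values."""
    if len(recon) != len(target):
        raise ValueError("images must have the same number of pixels")
    n = len(recon)
    mean_r = sum(recon) / n
    mean_t = sum(target) / n
    cov = sum((r - mean_r) * (t - mean_t) for r, t in zip(recon, target))
    var_r = sum((r - mean_r) ** 2 for r in recon)
    var_t = sum((t - mean_t) ** 2 for t in target)
    # Undefined for constant images (zero variance); callers should handle that case.
    return cov / math.sqrt(var_r * var_t)


# A reconstruction that is a brightened copy of the target correlates perfectly.
target = [0.0, 0.25, 0.5, 1.0]
recon = [0.1, 0.35, 0.6, 1.1]
print(round(pixel_correlation(recon, target), 6))
```

Averaging this score over all test images gives a single number per model. High-level metrics instead compare features extracted by pretrained networks.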

Training MindEye (low-level pipeline)

Run train_autoencoder.py to train the MindEye low-level model.

Before training, set the train_url, val_url and meta_url variables in the Python file to the relevant NSD dataset location. The training code expects the weights for VICRegL ConvNext-XL (download from here) and the Stable Diffusion Image Variations autoencoder (download from here). Note that these SD autoencoder weights were extracted as-is from the LambdaLabs SD Image Variations v2 model here.

Citation

If you make use of this work, please cite both the MindEye paper and the Natural Scenes Dataset paper. Please also consider citing our newer MindEye2 paper, which improves upon MindEye1.

MindEye2

Scotti, Tripathy, Torrico, Kneeland, Chen, Narang, Santhirasegaran, Xu, Naselaris, Norman, & Abraham. MindEye2: Shared-Subject Models Enable fMRI-To-Image With 1 Hour of Data. International Conference on Machine Learning. (2024). arXiv:2403.11207

MindEye1

Scotti, Banerjee, Goode, Shabalin, Nguyen, Cohen, Dempster, Verlinde, Yundler, Weisberg, Norman, & Abraham. Reconstructing the Mind's Eye: fMRI-to-Image with Contrastive Learning and Diffusion Priors. Advances in Neural Information Processing Systems, 36. (2023). arXiv:2305.18274.

Natural Scenes Dataset

Allen, St-Yves, Wu, Breedlove, Prince, Dowdle, Nau, Caron, Pestilli, Charest, Hutchinson, Naselaris, & Kay. A massive 7T fMRI dataset to bridge cognitive neuroscience and artificial intelligence. Nature Neuroscience (2021).