
<div align="center">

# Slot Attention Lightning

<a href="https://pytorch.org/get-started/locally/"><img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white"></a> <a href="https://pytorchlightning.ai/"><img alt="Lightning" src="https://img.shields.io/badge/-Lightning-792ee5?logo=pytorchlightning&logoColor=white"></a> <a href="https://hydra.cc/"><img alt="Config: Hydra" src="https://img.shields.io/badge/Config-Hydra-89b8cd"></a> <a href="https://github.com/ashleve/lightning-hydra-template"><img alt="Template" src="https://img.shields.io/badge/-Lightning--Hydra--Template-017F2F?style=flat&logo=github&labelColor=gray"></a><br>

</div> <br>

## Description

This repo implements baseline methods for unsupervised object-centric learning, including IODINE, MONet, Slot Attention, and Genesis V2. The implementations of IODINE, MONet, and Genesis V2 are adapted from here.

<p align="center"> <img src="doc/wandb_results.png" /> </p> <p align = "center"> &#8593;&#8593;&#8593; Visualization of training results logged by <a href="https://docs.wandb.ai/">WandB</a> &#8593;&#8593;&#8593; </p> <br>

## Repository Structure

<details> <summary> The directory structure of this repo looks like this: </summary> <div markdown="1">

```
├── .github                   <- Github Actions workflows
│
├── configs                   <- Hydra configs
│   ├── callbacks                <- Callbacks configs
│   ├── data                     <- Data configs
│   ├── debug                    <- Debugging configs
│   ├── experiment               <- *** Experiment configs ***
│   │   ├── slota
│   │   │  ├── clv6.yaml
│   │   │  └── ...
│   │   └── ...
│   ├── extras                   <- Extra utilities configs
│   ├── hparams_search           <- Hyperparameter search configs
│   ├── hydra                    <- Hydra configs
│   ├── local                    <- Local configs
│   ├── logger                   <- Logger configs (we use wandb)
│   ├── model                    <- Model configs
│   ├── paths                    <- Project paths configs
│   ├── trainer                  <- Trainer configs
│   │
│   ├── eval.yaml             <- Main config for evaluation
│   └── train.yaml            <- Main config for training
│
├── data                      <- Directory for Dataset
│   ├── CLEVR6
│   │   ├── images            <- raw images
│   │   │   ├── train
│   │   │   │   ├── CLEVR_train_******.png
│   │   │   │   └── ...
│   │   │   └── val
│   │   │       ├── CLEVR_val_******.png
│   │   │       └── ...
│   │   ├── masks             <- mask annotations
│   │   │   ├── train
│   │   │   │   ├── CLEVR_train_******.png
│   │   │   │   └── ...
│   │   │   └── val
│   │   │       ├── CLEVR_val_******.png
│   │   │       └── ...
│   │   └── scenes          <- metadata
│   │       ├── CLEVR_train_scenes.json
│   │       └── CLEVR_val_scenes.json
│   └── ...
│
├── logs                   <- Logs generated by hydra and lightning loggers
│
├── scripts                <- Shell scripts
│
├── src                    <- Source code
│   ├── data                     <- Data scripts
│   ├── models                   <- Model scripts
│   ├── utils                    <- Utility scripts
│   │
│   ├── eval.py                  <- Run evaluation
│   └── train.py                 <- Run training
│
├── tests                  <- Tests of any kind
│
├── .env.example              <- Example of file for storing private environment variables
├── .gitignore                <- List of files ignored by git
├── .pre-commit-config.yaml   <- Configuration of pre-commit hooks for code formatting
├── .project-root             <- File for inferring the position of project root directory
├── environment.yaml          <- File for installing conda environment
├── Makefile                  <- Makefile with commands like `make train` or `make test`
├── pyproject.toml            <- Configuration options for testing and linting
├── requirements.txt          <- File for installing python dependencies
├── setup.py                  <- File for installing project as a package
└── README.md
```

**Note:** Each dataset may provide its mask annotations and metadata in a different format, so you should match each dataset's `Dataset` class with the corresponding configuration.

</div> </details> <br>
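As an illustration of the note above, here is a minimal, hypothetical sketch of indexing the CLEVR6 layout shown in the tree (the repo's actual `Dataset` classes live in `src/data` and would subclass `torch.utils.data.Dataset`; the class name and behavior here are assumptions for illustration only):

```python
from pathlib import Path


class CLEVR6Paths:
    """Pair raw images with their mask annotations for one split.

    Assumes the directory layout shown above:
        <data_dir>/images/<split>/CLEVR_<split>_*.png
        <data_dir>/masks/<split>/CLEVR_<split>_*.png
    A real Dataset class would also load and transform the images
    in __getitem__ and parse the metadata in scenes/.
    """

    def __init__(self, data_dir: str, split: str = "train"):
        root = Path(data_dir)
        self.image_paths = sorted((root / "images" / split).glob("*.png"))
        # Each mask shares the file name of its image, per the layout above.
        self.mask_paths = [root / "masks" / split / p.name for p in self.image_paths]

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int):
        return self.image_paths[idx], self.mask_paths[idx]
```

A dataset that stores masks differently (e.g. one multi-channel file per scene) would need its own pairing logic, which is why each dataset gets its own class and config.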

## Installation

This repo is built on Lightning-Hydra-Template 1.5.3 with Python 3.8.12 and PyTorch 1.11.0.

### Pip

```bash
# clone project
git clone https://github.com/janghyuk-choi/slot-attention-lightning.git
cd slot-attention-lightning

# [OPTIONAL] create conda environment
conda create -n slota python=3.8
conda activate slota

# install pytorch according to the instructions at
# https://pytorch.org/get-started/

# install requirements
pip install -r requirements.txt
```

### Conda

```bash
# clone project
git clone https://github.com/janghyuk-choi/slot-attention-lightning.git
cd slot-attention-lightning

# create conda environment and install dependencies
conda env create -f environment.yaml

# activate conda environment
conda activate slota
```
<br>

## How to run

Train a model with a chosen experiment configuration from `configs/experiment/`.

### Training

```bash
# train Slot Attention on the CLEVR6 dataset
python src/train.py \
  experiment=slota/clv6.yaml

# train Genesis V2 on the CLEVRTEX dataset
python src/train.py \
  experiment=genesis2/clvt.yaml
```

You can create your own experiment configs for this purpose.
For simple modifications, however, you can override any parameter from the command line.

```bash
# train Slot Attention on the CLEVR6 dataset with a custom config
python src/train.py \
  experiment=slota/clv6.yaml \
  data.data_dir=/workspace/dataset/clevr_with_masks/CLEVR6 \
  trainer.check_val_every_n_epoch=10 \
  model.net.num_slots=10 \
  model.net.num_iter=5 \
  model.name="slota_k10_t5" # model.name will be used for logging on wandb
```
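Hydra resolves each dotted `key=value` override against the nested config tree built from the YAML files. Conceptually it works like the simplified sketch below (an illustration only — Hydra/OmegaConf additionally performs type conversion, interpolation, and validation, which this sketch omits, so values stay strings here):

```python
def apply_override(cfg: dict, override: str) -> None:
    """Apply one 'a.b.c=value' style override to a nested dict in place."""
    key, _, value = override.partition("=")
    *parents, leaf = key.split(".")
    node = cfg
    for parent in parents:
        # Walk (or create) each level of the nested config.
        node = node.setdefault(parent, {})
    node[leaf] = value


# A toy config tree standing in for the composed Hydra config.
cfg = {"model": {"net": {"num_slots": 7, "num_iter": 3}}}
apply_override(cfg, "model.net.num_slots=10")
```

This is why any field that appears in the composed config, e.g. `trainer.check_val_every_n_epoch`, can be overridden without editing a YAML file.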

### Evaluation

You can evaluate a trained model with the corresponding checkpoint.
Evaluation is also performed during training, at intervals of `trainer.check_val_every_n_epoch` epochs.

```bash
# evaluate Slot Attention on the CLEVR6 dataset.
# as in training, you can also customize the config from the command line
python src/eval.py \
  experiment=slota/clv6.yaml \
  ckpt_path=logs/train/runs/clv6_slota/{timestamp}/checkpoints/last.ckpt
```
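The `{timestamp}` placeholder stands for a run's timestamped log directory. If several runs exist, a small helper (hypothetical, not part of the repo) can pick the most recently written `last.ckpt` to pass as `ckpt_path`:

```python
from pathlib import Path


def latest_checkpoint(run_root: str) -> Path:
    """Return the most recently modified last.ckpt under run_root.

    Assumes the layout shown above: <run_root>/<timestamp>/checkpoints/last.ckpt
    """
    candidates = sorted(
        Path(run_root).glob("*/checkpoints/last.ckpt"),
        key=lambda p: p.stat().st_mtime,
    )
    if not candidates:
        raise FileNotFoundError(f"no last.ckpt found under {run_root}")
    return candidates[-1]
```

For example, `latest_checkpoint("logs/train/runs/clv6_slota")` would return the checkpoint of the newest run of that experiment.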