# DeMansia
## About
DeMansia is a model that integrates ViM (Vision Mamba) with token labeling techniques to enhance performance on image classification tasks.
## Installation
We provide a simple `setup.sh` script to install the Conda environment. You need to satisfy the following prerequisites:
- Linux
- NVIDIA GPU
- CUDA 11.8+ supported GPU driver
- Miniforge
Then, simply run `source ./setup.sh` to get started.
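Once the environment is built, a quick optional check from Python confirms that PyTorch can see your GPU (an illustrative snippet, not part of `setup.sh`):

```python
import torch

# Verify that the CUDA toolchain installed by setup.sh is visible to PyTorch
print(torch.__version__)
print(torch.cuda.is_available())      # should print True
print(torch.cuda.get_device_name(0))  # should name your NVIDIA GPU
```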
## Pretrained Models
These models were trained on the ImageNet-1k dataset using a single RTX 6000 Ada during our experiments.
Currently, only DeMansia Tiny is available. We will release more models as opportunities arise.
| Name | Model Dim. | Num. of Layers | Num. of Param. | Input Res. | Top-1 | Top-5 | Batch Size | Download | Training Log |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| DeMansia Tiny | 192 | 24 | 8.06M | 224² | 79.37% | 94.51% | 768 | link | log |
## Training and inference
To set up the ImageNet-1k dataset, download both the training and validation sets. Use this script to extract and organize the dataset. You should also download and extract the token labeling dataset from here.
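Once extracted, the dataset should follow the standard ImageNet folder layout. As a rough, minimal illustration (the path is a placeholder, and the actual train-time augmentations live in `module/data`), the validation split can be loaded like this:

```python
from torchvision import datasets, transforms

# Minimal eval-time preprocessing matching the 224x224 input resolution;
# training uses the augmentation pipeline from module/data instead
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# Expects the layout val/<class_name>/<image>.JPEG
val_set = datasets.ImageFolder("path/to/imagenet/val", transform=transform)
```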
We provide `DeMansia train.ipynb`, which contains all the necessary code to train a DeMansia model and log the training progress. The logged parameters can be modified in `model.py`.
The base model's hyperparameters are stored in `model_config.py`, and you can adjust them as needed. When further training our model, note that all hyperparameters are saved directly in the checkpoint file; for more information, refer to PyTorch Lightning's documentation. The same applies to inference, since PyTorch Lightning automatically restores all saved parameters when loading our model.
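For orientation, a minimal training setup might look like the sketch below. This is an assumption-laden outline, not our exact recipe: it presumes the `DeMansia` LightningModule from `model.py` can be constructed with the defaults from `model_config.py`, and `datamodule` stands in for an ImageNet-1k data pipeline built with `module/data`:

```python
import pytorch_lightning as pl

from model import DeMansia

# Assumption: the constructor falls back to the defaults in model_config.py
model = DeMansia()

trainer = pl.Trainer(
    max_epochs=300,     # illustrative value, not our exact schedule
    accelerator="gpu",
    devices=1,
)

# `datamodule` is a placeholder for your ImageNet-1k LightningDataModule;
# swap in a real one before running:
# trainer.fit(model, datamodule=datamodule)

# Further training from a checkpoint restores the saved hyperparameters:
# trainer.fit(model, datamodule=datamodule, ckpt_path="path_to.ckpt")
```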
Here's a sample code snippet to perform inference with DeMansia:
```python
import torch

from model import DeMansia

# Load a trained checkpoint; Lightning restores the saved hyperparameters
model = DeMansia.load_from_checkpoint("path_to.ckpt")
model.eval()

sample = torch.rand(3, 224, 224)  # Channel, Height, Width
sample = sample.unsqueeze(0)      # Batch, Channel, Height, Width

with torch.inference_mode():
    pred = model(sample)  # Batch, number of classes
```
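The model outputs raw logits. To turn them into class probabilities and read off the top-5 predictions (the same metric reported in the table above), a small follow-up like this works:

```python
import torch.nn.functional as F

# Convert logits to probabilities and take the five most likely classes
probs = F.softmax(pred, dim=-1)              # Batch, number of classes
top5_prob, top5_idx = probs.topk(5, dim=-1)  # each of shape (Batch, 5)
```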
## Credits
Our work builds upon the remarkable achievements of Mamba, ViM and LV-ViT.
`custom_mamba/` is taken from ViM's repo <3.
`module/data` and `module/loss` are modified from the LV-ViT repo.
`module/ema` is modified from here.