<img src="assets/turtle.png" alt="Lego Turtle" width="50"> Turtle: Learning Truncated Causal History Model for Video Restoration [NeurIPS'2024]
The official PyTorch implementation for Learning Truncated Causal History Model for Video Restoration, accepted to NeurIPS 2024.
- Turtle achieves state-of-the-art results on multiple video restoration benchmarks, offering superior computational efficiency and enhanced restoration quality 🔥🔥🔥.
- 🛠️💡 Model Forge: Easily design your own architecture by modifying the option file.
- You have the flexibility to choose from various types of layers (channel attention, simple channel attention, CHM, FHR, or custom blocks) as well as different types of feed-forward layers.
- This setup allows you to create custom networks and experiment with layer and feed-forward configurations to suit your needs.
- If you like this project, please give us a ⭐ on GitHub!
🔥 📰 News 🔥
- Oct. 10, 2024: The paper is now available on arXiv along with the code and pretrained models.
- Sept. 25, 2024: Turtle is accepted to NeurIPS'2024.
Installation
This implementation is based on BasicSR, an open-source toolbox for image and video restoration tasks.
python 3.9.5
pytorch 1.11.0
cuda 11.3
pip install -r requirements.txt
python setup.py develop --no_cuda_ext
Trained Models
You can download our trained models from Google Drive: Trained Models
1. Dataset Preparation
To obtain the datasets, follow the official instructions provided by each dataset's provider and download them into the datasets folder. The datasets for each task are available from the following links (official sources reported by their respective authors):
- Desnowing: RSVD
- Raindrops and Rainstreaks Removal: VRDS
- Night Deraining: NightRain
- Synthetic Deblurring: GoPro
- Real-World Deblurring: BSD3ms-24ms
- Denoising: DAVIS | Set8
- Real-World Super Resolution: MVSR
Organize the directory structure as follows, with 'gt' holding the ground-truth (reference) frames and 'blur' holding the degraded frames:
./datasets/
└── Dataset_name/
    ├── train/
    └── test/
        ├── blur/
        │   ├── video_1/
        │   │   ├── Frame1
        │   │   └── ...
        │   └── video_n/
        │       ├── Frame1
        │       └── ...
        └── gt/
            ├── video_1/
            │   ├── Frame1
            │   └── ...
            └── video_n/
                ├── Frame1
                └── ...
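Before training or testing, it can help to confirm that every degraded video has a matching ground-truth video. The following is a minimal sketch, not part of the repository: check_pairs is a hypothetical helper, and it assumes degraded and ground-truth frames are paired by identical video folder names and frame file names, which is how the layout above reads.

```python
from pathlib import Path


def check_pairs(split_dir):
    """Return a list of problems found in a blur/gt split (an empty list means OK).

    Assumes the layout <split>/blur/<video>/<frame> and <split>/gt/<video>/<frame>,
    with degraded and ground-truth frames paired by identical names (an assumption).
    """
    split = Path(split_dir)
    blur_root, gt_root = split / "blur", split / "gt"
    problems = []
    for root in (blur_root, gt_root):
        if not root.is_dir():
            problems.append(f"missing folder: {root}")
    if problems:
        return problems

    blur_videos = {p.name for p in blur_root.iterdir() if p.is_dir()}
    gt_videos = {p.name for p in gt_root.iterdir() if p.is_dir()}
    # Videos that exist on only one side cannot be paired at all.
    for name in sorted(blur_videos ^ gt_videos):
        problems.append(f"video folder present in only one of blur/gt: {name}")

    # For videos on both sides, the frame file names must match one to one.
    for name in sorted(blur_videos & gt_videos):
        blur_frames = {f.name for f in (blur_root / name).iterdir() if f.is_file()}
        gt_frames = {f.name for f in (gt_root / name).iterdir() if f.is_file()}
        if blur_frames != gt_frames:
            problems.append(f"{name}: frame names differ between blur and gt")
    return problems


if __name__ == "__main__":
    # Hypothetical path: replace Dataset_name with your dataset folder.
    for issue in check_pairs("./datasets/Dataset_name/test"):
        print(issue)
```

Run it once per split (train and test) after downloading; an empty output means every video and frame has a counterpart.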
2. Training
To train the model, make sure you select the appropriate data loader in train.py. There are two options:
- For deblurring, denoising, deraining, etc., keep the following import line and comment out the super-resolution one:
  from basicsr.data.video_image_dataset import VideoImageDataset
- For super-resolution, keep the following import line and comment out the previous one:
  from basicsr.data.video_super_image_dataset import VideoSuperImageDataset as VideoImageDataset
Then launch training (this example starts 8 processes on a single node, one per GPU):
python -m torch.distributed.launch --nproc_per_node=8 --master_port=8080 basicsr/train.py -opt /options/option_file_name.yml --launcher pytorch
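As an alternative to commenting the two import lines in and out by hand, the loader could be chosen in code. This is a hypothetical sketch: DATASET_LOADERS and load_dataset_class are not part of the repository, the sketch only reuses the module paths and class names from the two import options above, and it assumes you are free to edit train.py this way.

```python
import importlib

# Hypothetical mapping built from the two import lines documented above.
# "restoration" stands for deblurring, denoising, deraining, etc.
DATASET_LOADERS = {
    "restoration": ("basicsr.data.video_image_dataset", "VideoImageDataset"),
    "superresolution": ("basicsr.data.video_super_image_dataset", "VideoSuperImageDataset"),
}


def load_dataset_class(task_kind):
    """Import and return the dataset class registered for task_kind."""
    try:
        module_name, class_name = DATASET_LOADERS[task_kind]
    except KeyError:
        raise ValueError(
            f"unknown task kind {task_kind!r}; expected one of {sorted(DATASET_LOADERS)}"
        ) from None
    module = importlib.import_module(module_name)  # imported only when requested
    return getattr(module, class_name)


# In train.py this would replace the hand-edited import, for example:
# VideoImageDataset = load_dataset_class("superresolution")
```

Importing lazily means the unused loader module is never loaded, which matches the effect of commenting it out.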
3. Evaluation
The pre-trained models can be downloaded from the Google Drive link in the Trained Models section.
3.1 Testing the model
To evaluate a pre-trained model, use this command:
python inference.py
Adjust the function parameters in inference.py according to each task's requirements:
- config: Specify the path to the option file.
- model_path: Provide the location of the pre-trained model.
- dataset_name: Select the dataset you are using ("RSVD", "GoPro", "SR", "NightRain", "DVD", "Set8").
- task_name: Choose the restoration task ("Desnowing", "Deblurring", "SR", "Deraining", "Denoising").
- model_type: Indicate the model type ("t0", "t1", "SR").
- save_image: Set to True if you want to save the output images, and provide the output path in image_out_path.
- do_patches: Enable if processing images in patches; adjust tile and tile_overlap as needed (default values are 320 and 128, respectively).
- y_channel_PSNR: Enable if you need to calculate PSNR/SSIM on the Y channel (default: False).
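The do_patches option implies processing large frames as overlapping tiles. The sketch below assumes a sliding window of the kind common in restoration test scripts (stride = tile - tile_overlap, with a final tile pinned to the frame border); it illustrates the idea and is not necessarily how inference.py tiles frames. tile_origins and tile_boxes are hypothetical helpers.

```python
def tile_origins(size, tile=320, overlap=128):
    """Start offsets of overlapping tiles along one axis (height or width).

    Tiles advance by stride = tile - overlap; a final tile is anchored at
    size - tile so the last pixels are always covered.
    """
    if overlap >= tile:
        raise ValueError("overlap must be smaller than tile")
    if size <= tile:
        return [0]  # a single tile (cropped to the frame) covers the axis
    stride = tile - overlap
    origins = list(range(0, size - tile, stride))
    origins.append(size - tile)
    return origins


def tile_boxes(height, width, tile=320, overlap=128):
    """(top, left, bottom, right) boxes covering a height x width frame."""
    boxes = []
    for top in tile_origins(height, tile, overlap):
        for left in tile_origins(width, tile, overlap):
            boxes.append((top, left, min(top + tile, height), min(left + tile, width)))
    return boxes


print(tile_origins(720))            # [0, 192, 384, 400]
print(len(tile_boxes(720, 1280)))   # 24
```

Each box would be restored separately; a common way to merge overlapping outputs is to sum the predictions into one canvas and divide by a per-pixel count of how many tiles covered each pixel. Larger overlaps reduce seams at tile borders at the cost of more tiles.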
3.2 Running Turtle on Custom Videos
This pipeline processes a video by extracting frames and running a pre-trained model for tasks like desnowing:
Step 1: Extract Frames from Video
- Edit video_to_frames.py:
  - Set video_path to your input video file.
  - Set output_folder to the folder where the extracted frames will be saved.
- Run the script:
python video_to_frames.py
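For reference, here is a minimal sketch of what frame extraction involves, using OpenCV (the opencv-python package). It is not the repository's video_to_frames.py; extract_frames and frame_filename are hypothetical helpers. The zero-padded names assume that frames are ordered by sorting file names, so that 00010.png comes after 00002.png.

```python
from pathlib import Path


def frame_filename(index, digits=5, ext=".png"):
    """Zero-padded name so that lexicographic order matches temporal order."""
    return f"{index:0{digits}d}{ext}"


def extract_frames(video_path, output_folder):
    """Write every frame of video_path into output_folder; return the frame count."""
    import cv2  # imported here so frame_filename works even without OpenCV installed

    out = Path(output_folder)
    out.mkdir(parents=True, exist_ok=True)
    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise IOError(f"cannot open video: {video_path}")
    count = 0
    try:
        while True:
            ok, frame = cap.read()  # BGR uint8 array; imwrite also expects BGR
            if not ok:
                break
            cv2.imwrite(str(out / frame_filename(count)), frame)
            count += 1
    finally:
        cap.release()
    return count
```

PNG is lossless, so the frames passed to the model are not degraded further by re-encoding.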
Step 2: Run Model Inference
- Edit inference_no_ground_truth.py:
  - Set the paths for config, model_path, data_dir (the extracted frames), and image_out_path (the output frames).
- Run the script:
python inference_no_ground_truth.py
4. Model complexity and inference speed
- To get the parameter count, MACs, and inference speed, use this command:
python basicsr/models/archs/turtle_arch.py
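For a quick check on your own variant (for example, one assembled with Model Forge), parameter count and latency can also be measured directly in PyTorch. This is a generic sketch, not the code in turtle_arch.py: count_parameters and average_latency are hypothetical helpers, the Conv2d is only a stand-in model, and MAC counting is omitted because it normally requires a profiling tool. Warm-up runs and torch.cuda.synchronize() keep asynchronous CUDA execution from distorting the timing.

```python
import time

import torch
import torch.nn as nn


def count_parameters(model):
    """Number of trainable parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


@torch.no_grad()
def average_latency(model, example, warmup=5, runs=20):
    """Mean forward-pass time in seconds over `runs` calls after `warmup` calls."""
    model.eval()
    for _ in range(warmup):
        model(example)
    if example.is_cuda:
        torch.cuda.synchronize()  # wait for queued kernels before starting the clock
    start = time.perf_counter()
    for _ in range(runs):
        model(example)
    if example.is_cuda:
        torch.cuda.synchronize()  # make sure all timed work has finished
    return (time.perf_counter() - start) / runs


if __name__ == "__main__":
    # Stand-in model; swap in your network and a realistically sized input.
    net = nn.Conv2d(3, 8, kernel_size=3, padding=1)
    x = torch.randn(1, 3, 64, 64)
    print(count_parameters(net))  # 224 (3*8*3*3 weights + 8 biases)
    print(f"{average_latency(net, x) * 1000:.2f} ms")
```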
5. Acknowledgments
This codebase borrows from the BasicSR and ShiftNet repositories.
6. Citation
If you find our work useful, please consider citing our paper in your research.
BibTeX