NU-Wave — Official PyTorch Implementation
NU-Wave: A Diffusion Probabilistic Model for Neural Audio Upsampling
Junhyeok Lee, Seungu Han @ MINDsLab Inc., SNU
Official PyTorch + PyTorch Lightning implementation of NU-Wave.
Update: fixed a typo in lightning_model.py, line 36 (10 → 20).
Errata added for the ISCA Archive and arXiv versions of the paper.
Checkpoint contribution: thanks to freds0, who released his checkpoint in issue #18!
Official checkpoints for the single-speaker setting are released on Google Drive.
Since the NU-Wave 2 repository is now open, we will try to handle issues in the new repository.
NU-Wave 2 has been accepted to Interspeech 2022! Code and checkpoints are available on GitHub!
Requirements
- PyTorch >= 1.7.0 (needed for nn.SiLU, the swish activation)
- PyTorch-Lightning == 1.1.6
- The full list of requirements is in requirements.txt.
- We also provide a Docker setup (Dockerfile).
Preprocessing
Before running our project, you need to download the dataset and preprocess it into .pt files.
- Download the VCTK dataset.
- Remove speakers p280 and p315.
- Set data:dir in hparameter.yaml to the path of the downloaded dataset.
- Run utils/wav2pt.py:
python utils/wav2pt.py
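The speaker-removal step and the wav-to-.pt file mapping can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the logic of utils/wav2pt.py: it assumes the dir/spk/file layout suggested by the `#dir/spk/format` comment in hparameter.yaml, the helper names `list_wav_files` and `pt_path_for` are hypothetical, and the `*.wav` pattern should be changed to match the files in your copy of the dataset.

```python
from pathlib import Path
from typing import Iterable, List

# Speakers the preprocessing instructions ask us to drop.
EXCLUDED_SPEAKERS = ("p280", "p315")


def list_wav_files(root: str,
                   excluded: Iterable[str] = EXCLUDED_SPEAKERS,
                   pattern: str = "*.wav") -> List[Path]:
    """Collect root/<speaker>/<pattern>, skipping excluded speaker folders.

    Hypothetical helper: assumes one sub-directory per speaker under root.
    """
    excluded = set(excluded)
    files: List[Path] = []
    for spk_dir in sorted(Path(root).iterdir()):
        if spk_dir.is_dir() and spk_dir.name not in excluded:
            files.extend(sorted(spk_dir.glob(pattern)))
    return files


def pt_path_for(audio_path: Path) -> Path:
    """Map an audio file to a .pt file next to it (assumed output layout)."""
    return audio_path.with_suffix(".pt")
```

A converter would then load each file returned by `list_wav_files` and write its tensor to `pt_path_for(path)`; keeping the .pt files beside the originals is what lets a glob such as `*mic1.pt` find them later.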
Training
- Adjust hparameter.yaml, especially the train section.
train:
batch_size: 18 # Dependent on GPU memory size
lr: 0.00003
weight_decay: 0.00
num_workers: 64 # Dependent on CPU cores
gpus: 2 # number of GPUs
opt_eps: 1e-9
beta1: 0.5
beta2: 0.999
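The optimizer-related keys above (lr, weight_decay, opt_eps, beta1, beta2) line up with the arguments of `torch.optim.Adam`. The sketch below assumes that is how they are consumed, which this README does not state, and `adam_kwargs` is a hypothetical helper. The `float()` casts are a defensive choice: plain PyYAML (`yaml.safe_load`) follows YAML 1.1 and reads a value such as `1e-9`, which has no decimal point, as the string `'1e-9'`.

```python
from typing import Any, Dict, Mapping


def adam_kwargs(train_cfg: Mapping[str, Any]) -> Dict[str, Any]:
    """Translate the `train` section into torch.optim.Adam keyword arguments.

    Hypothetical helper: assumes the model is optimized with Adam using these keys.
    """
    return {
        "lr": float(train_cfg["lr"]),
        "weight_decay": float(train_cfg["weight_decay"]),
        # float() also accepts '1e-9' if the YAML loader returned it as a string.
        "eps": float(train_cfg["opt_eps"]),
        "betas": (float(train_cfg["beta1"]), float(train_cfg["beta2"])),
    }


# Values copied from the train section; opt_eps is a string, as plain PyYAML would return it.
train = {"lr": 0.00003, "weight_decay": 0.00, "opt_eps": "1e-9",
         "beta1": 0.5, "beta2": 0.999}
kwargs = adam_kwargs(train)
# With PyTorch installed: optimizer = torch.optim.Adam(model.parameters(), **kwargs)
```

Keys that only affect data loading or hardware (batch_size, num_workers, gpus) are deliberately left out, since they do not belong to the optimizer.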
- If you want to train on a single speaker, use VCTKSingleSpkDataset instead of VCTKMultiSpkDataset for the dataset in dataloader.py, and use batch_size=1 for the validation dataloader.
- Adjust the data section in hparameter.yaml.
data:
dir: '/DATA1/VCTK/VCTK-Corpus/wav48/p225' #dir/spk/format
format: '*mic1.pt'
cv_ratio: (223./231., 8./231., 0.00) #train/val/test
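In the data section, cv_ratio is written as a Python-style tuple expression, so a YAML loader hands it over as a string. How the repository evaluates that string is not shown here; the sketch below is a hypothetical, `eval`-free parser plus a split helper (`parse_cv_ratio` and `split_by_ratio` are illustrative names), assuming the three fractions are train/val/test shares applied to a list of files.

```python
import ast
import operator
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")

# Only arithmetic on numeric literals is allowed, so arbitrary code cannot run.
_OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
        ast.Mult: operator.mul, ast.Div: operator.truediv}


def _eval_number(node: ast.AST) -> float:
    """Evaluate a numeric literal or a +, -, *, / expression built from them."""
    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
        return float(node.value)
    if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
        return _OPS[type(node.op)](_eval_number(node.left), _eval_number(node.right))
    raise ValueError(f"unsupported expression: {ast.dump(node)}")


def parse_cv_ratio(text: str) -> Tuple[float, ...]:
    """Parse a string like '(223./231., 8./231., 0.00)' into a tuple of floats."""
    tree = ast.parse(text.strip(), mode="eval")
    if not isinstance(tree.body, ast.Tuple):
        raise ValueError("cv_ratio must be a tuple expression")
    return tuple(_eval_number(elt) for elt in tree.body.elts)


def split_by_ratio(items: Sequence[T],
                   ratio: Tuple[float, ...]) -> Tuple[List[T], List[T], List[T]]:
    """Split items into train/val/test lists in order (hypothetical rounding choice)."""
    n = len(items)
    n_train = round(n * ratio[0])
    n_val = round(n * ratio[1])
    return (list(items[:n_train]),
            list(items[n_train:n_train + n_val]),
            list(items[n_train + n_val:]))
```

Rounding with `round()` rather than truncating with `int()` keeps a product that lands just below an integer because of floating-point error from losing a file; whatever is left over goes to the test split.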
- Run trainer.py:
$ python trainer.py
- To resume training from a checkpoint, check the argument parser:
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-r', '--resume_from', type=int,
                    required=False, help="Resume Checkpoint epoch number")
parser.add_argument('-s', '--restart', action="store_true",
                    required=False, help="Significant change occurred, use this")
parser.add_argument('-e', '--ema', action="store_true",
                    required=False, help="Start from ema checkpoint")
args = parser.parse_args()
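To see how these flags combine, the training parser can be exercised with an explicit argument list. The sketch wraps it in a hypothetical `build_train_parser()` function so it can be run without a command line; the epoch number 100 is only an example value. Per the help strings, `-r 100 -e` means resuming from the EMA checkpoint of epoch 100.

```python
import argparse


def build_train_parser() -> argparse.ArgumentParser:
    """Rebuild the trainer.py argument parser (wrapped in a function for testing)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--resume_from', type=int, required=False,
                        help="Resume Checkpoint epoch number")
    parser.add_argument('-s', '--restart', action="store_true", required=False,
                        help="Significant change occurred, use this")
    parser.add_argument('-e', '--ema', action="store_true", required=False,
                        help="Start from ema checkpoint")
    return parser


# Equivalent to: python trainer.py -r 100 -e
resume_args = build_train_parser().parse_args(['-r', '100', '-e'])
# Equivalent to: python trainer.py   (fresh run, every option left at its default)
fresh_args = build_train_parser().parse_args([])
```

Because every option is optional, a plain `python trainer.py` starts from scratch with `resume_from` set to `None`.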
- During training, the TensorBoard logger records the loss, spectrograms, and audio. To view them:
$ tensorboard --logdir=./tensorboard --bind_all
Evaluation
Run for_test.py or test.py:
python test.py -r {checkpoint_number} {-e:option, if ema} {--save:option}
or
python for_test.py -r {checkpoint_number} {-e:option, if ema} {--save:option}
Please check the argument parser:
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-r', '--resume_from', type=int,
                    required=True, help="Resume Checkpoint epoch number")
parser.add_argument('-e', '--ema', action="store_true",
                    required=False, help="Start from ema checkpoint")
parser.add_argument('--save', action="store_true",
                    required=False, help="Save file")
args = parser.parse_args()
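Unlike the training parser, `-r` is required here, so every evaluation run must name a checkpoint epoch. The sketch below rebuilds this parser in a hypothetical `build_eval_parser()` function and shows how the placeholder command line maps onto concrete flags; the epoch number is an arbitrary example.

```python
import argparse


def build_eval_parser() -> argparse.ArgumentParser:
    """Rebuild the test.py / for_test.py argument parser (wrapped for testing)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--resume_from', type=int, required=True,
                        help="Resume Checkpoint epoch number")
    parser.add_argument('-e', '--ema', action="store_true", required=False,
                        help="Start from ema checkpoint")
    parser.add_argument('--save', action="store_true", required=False,
                        help="Save file")
    return parser


parser = build_eval_parser()
# Equivalent to: python for_test.py -r 100 -e --save
args = parser.parse_args(['-r', '100', '-e', '--save'])

# Omitting -r makes argparse print a usage error and exit with status 2.
try:
    parser.parse_args([])
    missing_r_rejected = False
except SystemExit as exc:
    missing_r_rejected = exc.code == 2
```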
While we provide Lightning-style test code in test.py, it has a device dependency. Thus, we recommend using for_test.py.
References
This implementation uses code from the following repositories:
- J.Ho's official DDPM implementation
- lucidrains' DDPM pytorch implementation
- ivanvovk's WaveGrad pytorch implementation
- lmnt-com's DiffWave pytorch implementation
This README and the webpage for the audio samples are inspired by:
- Tips for Publishing Research Code
- Audio samples webpage of DCA
- Cotatron
- Audio samples webpage of WaveGrad
The audio samples on our webpage are partially derived from:
- VCTK dataset (0.92): 46 hours of English speech from 108 speakers.
Repository Structure
.
├── Dockerfile
├── dataloader.py # Dataloader for train/val(=test)
├── filters.py # Filter implementation
├── test.py # Test with lightning_loop.
├── for_test.py # Test with for_loop. Recommended due to device dependency of lightning
├── hparameter.yaml # Config
├── lightning_model.py # NU-Wave implementation. DDPM is based on ivanvovk's WaveGrad implementation
├── model.py # NU-Wave model based on lmnt-com's DiffWave implementation
├── requirement.txt # required libraries
├── sampling.py # Sampling a file
├── trainer.py # Lightning trainer
├── README.md
├── LICENSE
├── utils
│ ├── stft.py # STFT layer
│ ├── tblogger.py # Tensorboard Logger for lightning
│ └── wav2pt.py # Preprocessing
└── docs # For github.io
└─ ...
Citation & Contact
If this repository is useful for your research, please consider citing it!
@inproceedings{lee21nuwave,
author={Junhyeok Lee and Seungu Han},
title={{NU-Wave: A Diffusion Probabilistic Model for Neural Audio Upsampling}},
year=2021,
booktitle={Proc. Interspeech 2021},
pages={1634--1638},
doi={10.21437/Interspeech.2021-36}
}
If you have any questions or inquiries, please contact Junhyeok Lee at jun3518@mindslab.ai