# NU-Wave 2 — Official PyTorch Implementation

**NU-Wave 2: A General Neural Audio Upsampling Model for Various Sampling Rates**<br>
Seungu Han, Junhyeok Lee @ MINDsLab Inc., SNU

Official PyTorch + Lightning implementation of NU-Wave 2.
The official checkpoint can be downloaded from here.

We added some additional samples of non-English (Korean) speech and an ablation study without BSFT to the demo page. Please check them out!

We also trained a model targeting 16 kHz (3.2 kHz ~ 16 kHz source). Its checkpoint can be downloaded from here.
## Requirements
- PyTorch >= 1.7.0 for `nn.SiLU` (the SiLU/Swish activation)
- PyTorch Lightning == 1.2.10
- All requirements are listed in `requirements.txt`.
- We also provide a Docker setup; see `Dockerfile`.
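After cloning the repository (next section), the dependencies can be installed with the standard pip workflow:

```bash
pip install -r requirements.txt
```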
## Clone our Repository

```bash
git clone --recursive https://github.com/mindslab-ai/nuwave2.git
cd nuwave2
```
## Preprocessing

Before running our project, you need to download the dataset and preprocess it into `.wav` files:

- Download the VCTK dataset (version 0.92).
- Remove speakers `p280` and `p315`.
- Set the path of the downloaded dataset in `data: base_dir` in `hparameter.yaml`.
- Run `utils/flac2wav.py`.

```bash
python utils/flac2wav.py
```
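For reference, below is a minimal sketch of what this kind of flac-to-wav conversion looks like. It is *not* the repository's `utils/flac2wav.py` (which also uses the silence timestamps from `vctk-silence-labels`); the paths mirror the `data` section of `hparameter.yaml`, and the `soundfile` dependency is an assumption.

```python
# Hypothetical flac -> wav conversion sketch, not the repository's script.
from pathlib import Path

import soundfile as sf  # assumed dependency

base_dir = Path('/DATA1/VCTK-0.92/wav48_silence_trimmed/')     # flac inputs
out_dir = Path('/DATA1/VCTK-0.92/wav48_silence_trimmed_wav/')  # wav outputs

for flac_path in base_dir.rglob('*mic1.flac'):
    audio, sr = sf.read(flac_path)  # VCTK 0.92 audio is 48 kHz
    wav_path = (out_dir / flac_path.relative_to(base_dir)).with_suffix('.wav')
    wav_path.parent.mkdir(parents=True, exist_ok=True)
    sf.write(wav_path, audio, sr)
```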
## Training

- Adjust `hparameter.yaml`, especially the `train` section.
```yaml
train:
  batch_size: 12 # Dependent on GPU memory size
  lr: 2e-4
  weight_decay: 0.00
  num_workers: 8 # Dependent on CPU cores
  gpus: 2 # number of GPUs
  opt_eps: 1e-9
  beta1: 0.9
  beta2: 0.99
```
- Adjust the `data` section in `hparameter.yaml`.
```yaml
data:
  timestamp_path: 'vctk-silence-labels/vctk-silences.0.92.txt'
  base_dir: '/DATA1/VCTK-0.92/wav48_silence_trimmed/'
  dir: '/DATA1/VCTK-0.92/wav48_silence_trimmed_wav/' # dir/spk/format
  format: '*mic1.wav'
  cv_ratio: (100./108., 8./108., 0.00) # train/val/test
```
- Run `trainer.py`.

```bash
$ python trainer.py
```
- If you want to resume training from a checkpoint, check the parser.
```python
parser = argparse.ArgumentParser()
parser.add_argument('-r', '--resume_from', type=int,
                    required=False, help="Resume Checkpoint epoch number")
parser.add_argument('-s', '--restart', action="store_true",
                    required=False, help="Significant change occurred, use this")
parser.add_argument('-e', '--ema', action="store_true",
                    required=False, help="Start from ema checkpoint")
args = parser.parse_args()
```
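For example, to resume from the checkpoint saved at epoch 100 (the epoch number is illustrative):

```bash
$ python trainer.py -r 100
```

Add `-e` to resume from the EMA checkpoint instead.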
- During training, the TensorBoard logger logs the loss, spectrograms, and audio.

```bash
$ tensorboard --logdir=./tensorboard --bind_all
```
## Evaluation

Run `for_test.py`. Note that `--sr` is required by the parser below.

```bash
python for_test.py -r {checkpoint_number} --sr {input sampling rate} {-e:option, if ema} {--save:option}
```
Please check the parser.
```python
parser = argparse.ArgumentParser()
parser.add_argument('-r', '--resume_from', type=int,
                    required=True, help="Resume Checkpoint epoch number")
parser.add_argument('-e', '--ema', action="store_true",
                    required=False, help="Start from ema checkpoint")
parser.add_argument('--save', action="store_true",
                    required=False, help="Save file")
parser.add_argument('--sr', type=int,
                    required=True, help="input sampling rate")
```
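For example, to evaluate the EMA checkpoint from epoch 100 on 12 kHz inputs and save the outputs (the epoch number and sampling rate are illustrative):

```bash
$ python for_test.py -r 100 --sr 12000 -e --save
```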
## Inference

- Run `inference.py`.

```bash
python inference.py -c {checkpoint_path} -i {input audio} --sr {sampling rate of input audio} {--steps:option} {--gt:option}
```
Please check the parser.

Note: if your input is a downsampled audio sample (12 kHz, 16 kHz, etc.) whose frequency content is fully valid up to its Nyquist frequency, pass `--sr {sampling rate of input audio}` and do not pass `--gt`. On the other hand, if you have a 48 kHz audio sample with fully valid frequency content and just want to check whether the model works well, pass `--sr {sampling rate you want to check}` and add the `--gt` flag.

Please check this issue for more information.
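Concretely (checkpoint and file names are placeholders):

```bash
# Case 1: a real 16 kHz recording with valid frequency content up to 8 kHz.
python inference.py -c nuwave2.ckpt -i input_16k.wav --sr 16000

# Case 2: a 48 kHz ground-truth file; with --gt it is reduced to the given
# rate first and then upsampled, so the result can be compared to the original.
python inference.py -c nuwave2.ckpt -i gt_48k.wav --sr 16000 --gt
```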
```python
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--checkpoint', type=str,
                    required=True, help="Checkpoint path")
parser.add_argument('-i', '--wav', type=str,
                    default=None, help="audio")
parser.add_argument('--sr', type=int,
                    required=True, help="Sampling rate of input audio")
parser.add_argument('--steps', type=int,
                    required=False, help="Steps for sampling")
parser.add_argument('--gt', action="store_true",
                    required=False, help="Whether the input audio is 48 kHz ground truth audio.")
parser.add_argument('--device', type=str,
                    default='cuda', required=False, help="Device, 'cuda' or 'cpu'")
```
## References

This implementation uses code from the following repositories:
- official NU-Wave PyTorch implementation
- revsic's JAX/Flax implementation of Variational-DiffWave
- ivanvovk's WaveGrad PyTorch implementation
- lmnt-com's DiffWave PyTorch implementation
- NVlabs' SPADE PyTorch implementation
- pkumivision's FFC PyTorch implementation
This README and the webpage for the audio samples are inspired by:
- Tips for Publishing Research Code
- Audio samples webpage of DCA
- Cotatron
- Audio samples webpage of WaveGrad

The audio samples on our webpage are partially derived from:
- VCTK dataset (0.92): 46 hours of English speech from 108 speakers.
- LJSpeech: a single-speaker English dataset consisting of 13,100 short audio clips of a female speaker reading passages from 7 non-fiction books, approximately 24 hours in total.
## Repository Structure

```
.
|-- Dockerfile
|-- LICENSE
|-- README.md
|-- dataloader.py       # Dataloader for train/val (= test)
|-- diffusion.py        # DPM
|-- for_test.py         # Test with for-loop
|-- hparameter.yaml     # Config
|-- inference.py        # Inference
|-- lightning_model.py  # NU-Wave 2 implementation
|-- model.py            # NU-Wave 2 model based on lmnt-com's DiffWave implementation
|-- requirements.txt    # Required libraries
|-- trainer.py          # Lightning trainer
|-- utils
|   |-- flac2wav.py     # Preprocessing
|   |-- stft.py         # STFT layer
|   `-- tblogger.py     # TensorBoard logger for Lightning
|-- docs                # For github.io
|   |-- ...
`-- vctk-silence-labels # For trimming
    |-- ...
```
## Citation & Contact

If you find this repository useful for your research, please consider citing it!

```bibtex
@article{han2022nu,
  title={NU-Wave 2: A General Neural Audio Upsampling Model for Various Sampling Rates},
  author={Han, Seungu and Lee, Junhyeok},
  journal={arXiv preprint arXiv:2206.08545},
  year={2022}
}
```

If you have a question or any kind of inquiry, please contact Seungu Han at hansw032@snu.ac.kr.