Whisper-AT: Noise-Robust Automatic Speech Recognizers are Also Strong Audio Event Taggers

Introduction

[Paper]

[HuggingFace Space] (Try Whisper-AT without Coding!)

[Colab Demo]

[Local Notebook Demo] (for users without Colab access)

[Python Package]

<p align="center"><img src="https://github.com/YuanGongND/whisper-at/blob/main/poster.png?raw=true" alt="Illustration of Whisper-AT." width="800"/></p> <div align="center"> (Please turn on audio to listen to the sounds) <video src="https://github.com/YuanGongND/whisper-at/assets/17163494/b479320a-b7f7-4bfc-acba-087b447623bd" width="400" /> </div>

Whisper-AT is a joint audio tagging and speech recognition model. It inherits strong speech recognition ability from OpenAI Whisper, and its ASR performance is exactly the same as the original Whisper. The API interface and usage are also identical to the original OpenAI Whisper, so users can seamlessly switch from the original Whisper to Whisper-AT.

The advantage of Whisper-AT is that, with minimal (less than 1%**) additional computational cost, it outputs general audio event labels (527-class AudioSet labels) at a desired temporal resolution in addition to the ASR transcripts. This makes audio tagging much easier and faster than running a standalone audio tagging model.

Internally, Whisper-AT freezes all original Whisper parameters, and trains a Time- and Layer-wise Transformer (TL-TR) on top of the Whisper encoder representations for the audio tagging task.

To help better understand the pros and cons of this work, we have attached the anonymous reviews and our responses [here]. We thank the anonymous reviewers for their invaluable comments.

** Not for all models; see the paper for details.

<hr style="border: 0; height: 1px; background-color: #e0e0e0;">

Quick Start (Run in 8 lines of code)

In shell,

pip install whisper-at

For Mac/Windows users, there is a known installation bug with the triton dependency; please use the following workaround:

# install all dependencies except triton
pip install numba numpy torch tqdm more-itertools tiktoken==0.3.3
# install whisper-at without any dependency
pip install --no-deps whisper-at  

Then, in Python,

import whisper_at as whisper

audio_tagging_time_resolution = 10
model = whisper.load_model("large-v1")
result = model.transcribe("audio.mp3", at_time_res=audio_tagging_time_resolution)
# ASR Results
print(result["text"])
# Audio Tagging Results
audio_tag_result = whisper.parse_at_label(result, language='follow_asr', top_k=5, p_threshold=-1, include_class_list=list(range(527)))
print(audio_tag_result)

Citation

Please cite our Interspeech 2023 paper if you find this repository useful.

@inproceedings{gong_whisperat,
  author={Gong, Yuan and Khurana, Sameer and Karlinsky, Leonid and Glass, James},
  title={Whisper-AT: Noise-Robust Automatic Speech Recognizers are Also Strong Audio Event Taggers},
  year=2023,
  booktitle={Proc. Interspeech 2023}
}

For Applications

The best way to learn how to use Whisper-AT is this [Colab Tutorial]; you can skip everything below if you read it. If you don't have Google Colab access (uncommon), you can use this [Local Notebook] as a substitute.

<hr style="border: 0; height: 1px; background-color: #e0e0e0;">

If you do not care how Whisper-AT is implemented and just want to use it, this section is all you need to read. It is very simple.

Step 1. Install Whisper-AT

We intentionally do not add any dependencies beyond those of the original Whisper. So if your environment can run the original Whisper, it can also run Whisper-AT.

Whisper-AT can be installed simply by:

pip install whisper-at

For Mac/Windows users, there is a known installation bug with the triton dependency; please use the following workaround:

# install all dependencies except triton
pip install numba numpy torch tqdm more-itertools tiktoken==0.3.3
# install whisper-at without any dependency
pip install --no-deps whisper-at  

Note that, following the original Whisper, Whisper-AT also requires the command-line tool ffmpeg to be installed on your system. Please check the OpenAI Whisper repo for details.

Step 2. Use as the Original Whisper

# note: the module is whisper_at (underscore), while the pip package is whisper-at (hyphen)
import whisper_at as whisper

# the only new thing in whisper-at
# specify the temporal resolution for audio tagging; 10 means Whisper-AT predicts audio events every 10 seconds (hop and window = 10s).
audio_tagging_time_resolution = 10

model = whisper.load_model("base")
# for large, medium, and small models, we provide low-dim proj AT models to save compute.
# model = whisper.load_model("large-v1", at_low_compute=True)
result = model.transcribe("audio.mp3", at_time_res=audio_tagging_time_resolution)
print(result["text"])

## translation task is also supported
# result = model.transcribe("audio.mp3", task='translate', at_time_res=audio_tagging_time_resolution)
# print(result["text"])

result["text"] contains the ASR transcripts. It is identical to the output of the original Whisper and is not affected by at_time_res; the ASR function still follows Whisper's 30-second window. at_time_res only affects audio tagging.

Compared to the original Whisper, the only new argument is at_time_res, the hop and window size with which Whisper-AT predicts audio events. For example, for a 60-second audio with at_time_res = 10, the audio is segmented into six 10-second segments, and Whisper-AT makes one audio event prediction per segment, i.e., six predictions in total. Note that at_time_res must be an integer multiple of 0.4, e.g., 0.4, 0.8, and so on. The default value is 10.0, which is the value we used to train the model and should give the best performance.
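As a quick sanity check of this arithmetic, the segment count and the multiple-of-0.4 constraint can be sketched in a few lines (at_segment_count is a hypothetical helper, not part of the whisper_at API):

```python
import math

def at_segment_count(audio_length_s, at_time_res=10.0):
    """Number of audio tagging predictions Whisper-AT makes for one clip."""
    # at_time_res must be an integer multiple of 0.4 s; compare with a
    # tolerance because 0.4 is not exactly representable in binary floats
    ratio = at_time_res / 0.4
    if abs(ratio - round(ratio)) > 1e-9:
        raise ValueError("at_time_res must be a multiple of 0.4 seconds")
    return math.ceil(audio_length_s / at_time_res)

print(at_segment_count(60, 10))  # 6: a 60-second clip yields six 10-second segments
print(at_segment_count(65, 10))  # 7: the last segment is only 5 seconds long
```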

Step 3. Get the Audio Tagging Output

Compared with the original Whisper, result contains a new entry called audio_tag. result['audio_tag'] is a torch tensor of shape [⌈audio_length/at_time_res⌉, 527]. For example, for a 60-second audio and at_time_res = 10, result['audio_tag'] is a tensor of shape [6, 527]. 527 is the size of the AudioSet label set; result['audio_tag'][i,j] is the (unnormalised) logit of class j for the ith segment.
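If you prefer to consume the raw logits yourself, per-segment top-k extraction can be sketched as below. For illustration it uses a plain nested list as a stand-in for the [num_segments, 527] tensor and a tiny made-up class list (with the real torch tensor you would typically use torch.topk instead):

```python
# Stand-in for result['audio_tag']: 2 segments x 4 classes of unnormalised logits.
# (The real tensor has 527 columns, one per AudioSet class.)
audio_tag = [
    [1.8, 0.9, -2.0, -3.1],   # segment 0
    [-0.5, 1.4, 0.2, -1.0],   # segment 1
]
class_names = ["Music", "Speech", "Dog", "Siren"]  # illustrative subset

def top_k_per_segment(logits, names, k=2):
    """Return the k highest-scoring (label, logit) pairs for each segment."""
    out = []
    for seg in logits:
        # rank class indices by logit, descending, and keep the top k
        ranked = sorted(range(len(seg)), key=lambda j: seg[j], reverse=True)[:k]
        out.append([(names[j], seg[j]) for j in ranked])
    return out

print(top_k_per_segment(audio_tag, class_names))
# segment 0 -> [('Music', 1.8), ('Speech', 0.9)]
```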

If you are familiar with audio tagging and AudioSet, you can use the raw result['audio_tag'] directly. But we also provide a tool to make this easier: you can feed the result to whisper.parse_at_label.

audio_tag_result = whisper.parse_at_label(result, language='follow_asr', top_k=5, p_threshold=-1, include_class_list=list(range(527)))
print(audio_tag_result)

# Outputs (audio tag, unnormalised logits):
# {'time': {'start': 0, 'end': 10}, 'audio tags': [('Music', 1.821943759918213), ('Speech', 0.9335958957672119)]}
# {'time': {'start': 10, 'end': 20}, 'audio tags': [('Music', 1.3550536632537842), ('Grunge', -1.3502553701400757), ('Progressive rock', -1.424593210220337), ('Punk rock', -1.5715707540512085)]}
# {'time': {'start': 20, 'end': 30}, 'audio tags': [('Music', 0.8052308559417725)]}
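Treating each printed line above as one segment record, a downstream consumer might, for example, find all time windows containing a given tag. A minimal sketch (segments mirrors the example output; segments_with_label is a hypothetical helper):

```python
# Segment records mirroring the example parse_at_label output above.
segments = [
    {'time': {'start': 0, 'end': 10},
     'audio tags': [('Music', 1.82), ('Speech', 0.93)]},
    {'time': {'start': 10, 'end': 20},
     'audio tags': [('Music', 1.36), ('Grunge', -1.35)]},
    {'time': {'start': 20, 'end': 30},
     'audio tags': [('Music', 0.81)]},
]

def segments_with_label(records, label):
    """Return the (start, end) windows whose tag list contains `label`."""
    return [(r['time']['start'], r['time']['end'])
            for r in records
            if any(tag == label for tag, _ in r['audio tags'])]

print(segments_with_label(segments, 'Speech'))  # [(0, 10)]
```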

Input arguments of whisper.parse_at_label (as used above): result is the output of model.transcribe; language controls the language of the output label names; top_k limits the number of audio tags per segment; p_threshold is a score threshold for filtering tags; include_class_list selects which of the 527 AudioSet classes to consider.

Return: a dictionary of audio tagging results.

This makes the audio tagging result human-readable, in the specified language. With language='follow_asr', whisper.parse_at_label outputs label names in the same language as the ASR output. That's it!

For Research

If you are interested in the findings and experiments in our Interspeech paper Whisper-AT: Noise-Robust Automatic Speech Recognizers are Also Strong Audio Event Taggers, please check this section. We provide our code to reproduce the experiments in the paper.

The paper mainly contains two contributions: (1) the finding that noise-robust ASR models learn noise-variant representations (Part 1 below), and (2) the Whisper-AT model itself, built by training a TL-TR head on frozen Whisper representations (Part 2 below).

<hr style="border: 0; height: 1px; background-color: #e0e0e0;">

Part 1. Noise-Variant Representations of Noise-Robust ASR

The most important finding of this paper is that a robust ASR actually learns a noise-variant representation; most previous work focuses on noise-invariant representations.

1.1 Whisper Feature Extraction

Since we freeze the Whisper model, in our experiments, we extract and save the Whisper features first.

There are two ways to extract Whisper features:

To facilitate reproduction, we release the ESC-50 features used for experiments [here].

1.2 Noise-Robust ASR Experiment (Figure 1 (upper))

The code for this part is [here] and [here].

1.3 ESC-50 Sound Classification Experiment (Figure 1 (lower))

The code for this part is [here] and [here].

1.4 Class-wise Noise-Robust ASR Experiment (Figure 2)

The code for this part is [here].

We use the same noise augmentation and ESC-50 sound classification methods as above, but now perform a class-wise analysis. Note that for each noise class, the test speech samples are the same, which ensures a fair comparison.

1.5 Best Whisper encoder layer for each sound class (Figure 3)

This part of code is [here].

<hr style="border: 0; height: 1px; background-color: #e0e0e0;">

Part 2. Whisper-AT Training

2.1 Whisper Feature Extraction

We save all features to disk and train TL-TR on top of them. This saves GPU memory but adds I/O cost. Please see Section 1.1 for how to extract features. Whichever method you use, the representation must be of shape [num_layer, 25, representation_dim], e.g., [32, 25, 1280] for Whisper-Large.
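Before training, it can be worth validating saved feature shapes. A minimal sketch, assuming the saved representation exposes a shape tuple as numpy or torch arrays do (check_whisper_feature is a hypothetical helper, not part of our recipe):

```python
def check_whisper_feature(shape, expected_layers=32, expected_dim=1280):
    """Validate one saved representation's shape for TL-TR training.

    The expected shape is [num_layer, 25, representation_dim],
    e.g. (32, 25, 1280) for Whisper-Large.
    """
    num_layer, num_frame, dim = shape
    assert num_layer == expected_layers, f"expected {expected_layers} layers, got {num_layer}"
    assert num_frame == 25, f"expected 25 time frames, got {num_frame}"
    assert dim == expected_dim, f"expected dim {expected_dim}, got {dim}"
    return True

print(check_whisper_feature((32, 25, 1280)))  # True for Whisper-Large features
```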

2.2 Time and Layer-wise Transformer (TL-TR) Model

The model code is [here].

2.3 Whisper-AT Training Recipe

The Whisper-AT training recipe is here. This contains everything needed to train Whisper-AT except the data.

The starting point is run_as_full_train.sh, which calls run.sh, which then calls traintest.py.

Hyper-parameters are:

| Model | Initial LR | Train Epochs (Equivalent**) | Weight Averaging |
|---|---|---|---|
| large | 5e-5 | 30 (3) | 16-30 |
| large (low proj) | 1e-4 | 30 (3) | 16-30 |
| medium | 5e-5 | 30 (3) | 16-30 |
| medium (low proj) | 1e-4 | 30 (3) | 16-30 |
| small | 1e-4 | 50 (5) | 21-50 |
| small (low proj) | 1e-4 | 50 (5) | 21-50 |
| base | 1e-4 | 50 (5) | 21-50 |
| tiny | 1e-4 | 50 (5) | 21-50 |

** We stop each epoch after 10% of its iterations are done, so the equivalent number of epochs = 0.1 × nominal epochs.
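The footnote's arithmetic, written out (equivalent_epochs is a trivial hypothetical helper, not part of the recipe):

```python
def equivalent_epochs(nominal_epochs, fraction=0.1):
    """Each epoch stops after `fraction` of its iterations, so the data
    actually seen equals fraction * nominal_epochs full passes."""
    return fraction * nominal_epochs

print(equivalent_epochs(30))  # 3.0 -> the "30 (3)" entries in the table
print(equivalent_epochs(50))  # 5.0 -> the "50 (5)" entries
```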

Training logs are also released [here].

2.4 FLOPs Calculation

The model code is [here].

<hr style="border: 0; height: 1px; background-color: #e0e0e0;">

Available Models and Audio Tagging Performance

The Whisper-AT script downloads the original OpenAI Whisper model and our AT model automatically, so you do not need to download them manually. But in case your device does not have Internet access, here are the [links].

| Model Name | #ASR Params | Language | #AT Params (TL-TR) | AS mAP (TL-TR) | #AT Params (TL-TR-512) | AS mAP (TL-TR-512) |
|---|---|---|---|---|---|---|
| large-v2 (large) | 1550M | Multilingual | 40.0M | 41.7 | 7.2M | 40.3 |
| large-v1 | 1550M | Multilingual | 40.0M | 42.1 | 7.2M | 41.6 |
| medium.en | 769M | English | 25.8M | 41.4 | 7.1M | 41.1 |
| medium | 769M | Multilingual | 25.8M | 40.8 | 7.1M | 41.2 |
| small.en | 244M | English | 14.6M | 40.1 | 6.9M | 39.9 |
| small | 244M | Multilingual | 14.6M | 39.8 | 6.9M | 39.8 |
| base.en | 74M | English | 6.6M | 37.5 | - | - |
| base | 74M | Multilingual | 6.6M | 37.6 | - | - |
| tiny.en | 39M | English | 3.8M | 35.8 | - | - |
| tiny | 39M | Multilingual | 3.8M | 36.5 | - | - |

License

Whisper-AT's code and model weights are released under a BSD license, which is similar to the original OpenAI Whisper's MIT license. Commercial use is welcome.

Contact

If you have a question, please create a GitHub issue (preferred) or send me an email at yuangong@mit.edu.